
Drop interleave placement in QKV matrix #1013

Merged

20 commits
bff05e4  Drop interleaved placement in QKV (Andrei-Aksionov, Mar 2, 2024)
0ed697f  Update test for test_llama2_70b_conversion (Andrei-Aksionov, Mar 5, 2024)
8e46dd5  Merge branch 'main' into qkv_drop_interleave_placement (Andrei-Aksionov, Mar 5, 2024)
b286470  Correct shapes for KV-cache (Andrei-Aksionov, Mar 7, 2024)
60f2c93  Always do .repeat_interleave for grouped queries in training mode. (Andrei-Aksionov, Mar 7, 2024)
0f41eb4  Test_convert_hf_checkpoint: test all branches (Andrei-Aksionov, Mar 8, 2024)
e25fb2c  test_convert_lit_checkpoint check all branches (Andrei-Aksionov, Mar 8, 2024)
61c3265  qkv_reassemble instead of qkv_split (Andrei-Aksionov, Mar 8, 2024)
0e8b18f  convert_lit: test for qkv_reassemble (Andrei-Aksionov, Mar 8, 2024)
c2323f7  Merge branch 'main' into qkv_drop_interleave_placement (Andrei-Aksionov, Mar 8, 2024)
089fc5c  Merge branch 'main' into qkv_drop_interleave_placement (Andrei-Aksionov, Mar 15, 2024)
a44b83d  Merge branch 'main' into qkv_drop_interleave_placement (Andrei-Aksionov, Nov 17, 2024)
b636a58  Fix the test (Andrei-Aksionov, Nov 17, 2024)
4371c06  Handle legacy checkpoints (Andrei-Aksionov, Nov 17, 2024)
465a9f7  attn.attn --> attn.qkv (Andrei-Aksionov, Nov 18, 2024)
5a48af1  Remove accidentally added files (Andrei-Aksionov, Nov 18, 2024)
311c2c5  Cleaner version of load_state_dict for legacy checkpoints (Andrei-Aksionov, Nov 19, 2024)
5520aef  Add note that SDPA is disabled for non None mask or softcapping (Andrei-Aksionov, Dec 26, 2024)
ac310a9  Merge branch 'main' into qkv_drop_interleave_placement (Andrei-Aksionov, Dec 26, 2024)
b7d82aa  Align the code with non-interleaved placement of QKV (Andrei-Aksionov, Dec 26, 2024)
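For context, the core change in this PR is the row placement inside the fused QKV projection: the legacy layout interleaves each query group's queries with its key and value, while the new layout stores all queries first, then all keys, then all values. Below is a toy illustration of the two orderings (hypothetical sizes, not litgpt code), with one label per head-sized block of rows:

```python
# Toy illustration of the two QKV placements (hypothetical sizes, not litgpt code).
n_head, n_query_groups = 12, 3
q_per_kv = n_head // n_query_groups  # 4 query heads share each key/value head

# Legacy interleaved placement: per group, its queries, then its key, then its value.
interleaved = []
for g in range(n_query_groups):
    interleaved += [f"Q{g * q_per_kv + i}" for i in range(q_per_kv)] + [f"K{g}", f"V{g}"]
print(interleaved)
# ['Q0', 'Q1', 'Q2', 'Q3', 'K0', 'V0', 'Q4', ..., 'K1', 'V1', 'Q8', ..., 'K2', 'V2']

# Placement after this PR: all queries, then all keys, then all values.
grouped = (
    [f"Q{i}" for i in range(n_head)]
    + [f"K{g}" for g in range(n_query_groups)]
    + [f"V{g}" for g in range(n_query_groups)]
)
print(grouped)
# ['Q0', ..., 'Q11', 'K0', 'K1', 'K2', 'V0', 'V1', 'V2']
```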
litgpt/adapter.py (2 changes: 1 addition & 1 deletion)
@@ -151,7 +151,7 @@ def scaled_dot_product_attention(
ak, av = self.adapter_kv_cache
else:
prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
aqkv = self.attn(prefix)
aqkv = self.qkv(prefix)
q_per_kv = self.config.n_head // self.config.n_query_groups
aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
aqkv = aqkv.permute(0, 2, 3, 1, 4)
litgpt/adapter_v2.py (16 changes: 12 additions & 4 deletions)
Expand Up @@ -21,6 +21,7 @@
from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.adapter import Config as BaseConfig
from litgpt.model import KVCache
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights


@@ -163,7 +164,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
# output projection
# if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
@@ -186,17 +187,24 @@ def __init__(self, config: Config, block_idx: int) -> None:
self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
"""For compatibility with base and/or legacy checkpoints."""
mapping = {
"attn.weight": "attn.linear.weight",
"attn.bias": "attn.linear.bias",
"qkv.weight": "qkv.linear.weight",
"qkv.bias": "qkv.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
# For compatibility with older checkpoints
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)

for attr in ("weight", "bias"):
legacy_key = f"{prefix}attn.linear.{attr}"
current_key = f"{prefix}qkv.linear.{attr}"
if legacy_key in state_dict:
state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)

super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


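The `qkv_reassemble` helper imported above is what converts a legacy interleaved QKV weight or bias into the new grouped layout when an old checkpoint is loaded. A minimal sketch of the reordering, assuming a config exposing `n_head`, `n_query_groups`, and `head_size` (the actual implementation in `litgpt/scripts/convert_hf_checkpoint.py` may differ in details):

```python
import torch


def qkv_reassemble_sketch(param: torch.Tensor, config) -> torch.Tensor:
    """Reorder an interleaved QKV tensor into [all Q | all K | all V] along dim 0.

    Illustrative only; works for both weights (2D) and biases (1D), since the
    fused QKV dimension is dim 0 in either case.
    """
    q_per_kv = config.n_head // config.n_query_groups
    # Each query group occupies (q_per_kv + 2) * head_size interleaved rows.
    groups = param.split((q_per_kv + 2) * config.head_size, dim=0)
    qs, ks, vs = [], [], []
    for group in groups:
        q, k, v = group.split(
            [q_per_kv * config.head_size, config.head_size, config.head_size], dim=0
        )
        qs.append(q)
        ks.append(k)
        vs.append(v)
    return torch.cat(qs + ks + vs, dim=0)
```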
litgpt/generate/tp.py (13 changes: 6 additions & 7 deletions)
@@ -3,31 +3,30 @@
import logging
import sys
import time
import warnings
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional, Union
import warnings

import lightning as L
from lightning_utilities.core.imports import RequirementCache
import torch
import torch._dynamo.config
import torch._inductor.config
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.utilities import rank_zero_only
from lightning_utilities.core.imports import RequirementCache

import litgpt.generate.base as generate_base
from litgpt.model import GPT
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer
from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
from litgpt.model import GPT, CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
check_nvlink_connectivity,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision
get_default_supported_precision,
)


@@ -71,7 +70,7 @@ def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMA


def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None:
tensor_parallel_linear(fabric, attn.attn, "colwise")
tensor_parallel_linear(fabric, attn.qkv, "colwise")
tensor_parallel_linear(fabric, attn.proj, "rowwise")
attn.register_forward_hook(partial(all_reduce_output, fabric.world_size))

litgpt/lora.py (63 changes: 25 additions & 38 deletions)
@@ -58,6 +58,7 @@
from litgpt.model import Block as BaseBlock
from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.model import KVCache
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights


@@ -267,18 +268,14 @@ def lora_ind(self) -> torch.Tensor:
# Indices are needed to properly pad weight updates with zeros.
if not hasattr(self, "_lora_ind"):
enable_q, enable_k, enable_v = self.enable_lora
qkv_group_size = self.n_head // self.n_query_groups + 2
candidate_indices = range(self.linear.out_features)
kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups)
lora_ind = []
if enable_q:
q_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size < qkv_group_size - 2]
lora_ind.extend(q_ind)
lora_ind.extend(range(0, self.linear.in_features))
if enable_k:
k_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 2]
lora_ind.extend(k_ind)
lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size))
if enable_v:
v_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 1]
lora_ind.extend(v_ind)
lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features))
self.register_buffer(
"_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
)
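With the grouped placement, the enabled LoRA rows become three contiguous ranges instead of a per-group interleaved pattern. A quick sanity check with made-up sizes (not values from the PR):

```python
# Made-up config: n_embd = in_features = 128, n_head = 4, n_query_groups = 2, head_size = 32.
in_features, n_head, n_query_groups, head_size = 128, 4, 2, 32
out_features = (n_head + 2 * n_query_groups) * head_size     # 256 fused QKV rows
kv_embd_size = in_features // (n_head // n_query_groups)     # 64 rows per key/value block

q_rows = range(0, in_features)                                # rows   0..127: all queries
k_rows = range(in_features, in_features + kv_embd_size)       # rows 128..191: all keys
v_rows = range(in_features + kv_embd_size, out_features)      # rows 192..255: all values
```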
@@ -298,27 +295,6 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
________________________________________
| query | key | value |
----------------------------------------
For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped
queries are adjacent to their associated key and value weights.
For example, suppose we have n_head = 12 with 3 query groups.
Then along the embedding dimension the interleaved weights would look like

[Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V],

where each Q, K, and V has size head_size.

In this case, the previously-described weight update applies separately to each
individual block, so the update will take the form

[[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...],
[.............................................................................],
[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]]
↑ ↑ ↑ ↑ ↑ ↑
________________________________________________________________________________
| q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ...
--------------------------------------------------------------------------------
Note that in the above diagram, the size of each q block will equal q_per_kv
times the size of each k and v block.

Args:
x: tensor with weights update that will be padded with zeros if necessary
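To make the padding concrete with the new contiguous ranges: when some of q/k/v are LoRA-disabled, their rows in the full-size update simply stay zero. A toy sketch of the idea (assumed shapes and semantics, not the actual `zero_pad` implementation):

```python
import torch

# Made-up fused-QKV sizes; LoRA enabled for queries and values, disabled for keys.
out_features, in_features, kv_embd_size, cols = 256, 128, 64, 16
lora_ind = torch.tensor(
    list(range(0, in_features))                                # query rows
    + list(range(in_features + kv_embd_size, out_features))    # value rows
)
delta = torch.randn(len(lora_ind), cols)    # LoRA update only for the enabled rows
padded = torch.zeros(out_features, cols)
padded[lora_ind] = delta                    # key rows 128..191 remain zero
```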
@@ -391,7 +367,9 @@ def get_lora_AB(self) -> torch.Tensor:
lora = self.conv1d(
self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
).squeeze(
0
) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
return self.zero_pad(lora.T * self.scaling).T # (256, 128) after zero_pad (384, 128)

def merge(self) -> None:
@@ -430,7 +408,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
after_B = self.conv1d(
after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
).transpose(
-2, -1
) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
return pretrained + lora

@@ -602,7 +582,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
self.attn = LoRAQKVLinear(
self.qkv = LoRAQKVLinear(
in_features=config.n_embd,
out_features=shape,
r=config.lora_r,
@@ -628,21 +608,28 @@ def __init__(self, config: Config, block_idx: int) -> None:
# disabled by default
self.kv_cache: Optional[KVCache] = None
self.apply_sliding_window_attention = (
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_stride == 0
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_stride == 0
)

self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
"""For compatibility with base and/or legacy checkpoints."""
mapping = {
"attn.weight": "attn.linear.weight",
"attn.bias": "attn.linear.bias",
"qkv.weight": "qkv.linear.weight",
"qkv.bias": "qkv.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)

for attr in ("weight", "bias"):
legacy_key = f"{prefix}attn.linear.{attr}"
current_key = f"{prefix}qkv.linear.{attr}"
if legacy_key in state_dict:
state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)

super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


@@ -758,4 +745,4 @@ def merge_lora_weights(model: GPT) -> None:
"""Merge LoRA weights into the full-rank weights to speed up inference."""
for module in model.modules():
if isinstance(module, LoRALinear):
module.merge()
module.merge()
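For reference, merging folds the already scaled and zero-padded low-rank product into the frozen fused QKV weight, so inference afterwards needs no extra matmuls. A conceptual sketch with made-up shapes, not litgpt's exact code:

```python
import torch

# Made-up shapes: fused QKV weight for n_embd = 128 with q, k, and v all LoRA-enabled.
W = torch.randn(384, 128)        # frozen pretrained weight (out_features, in_features)
delta_W = torch.randn(384, 128)  # roughly what get_lora_AB() returns: zero-padded, scaled B @ A
W_merged = W + delta_W           # merge() does this once; forward() then needs a single matmul
```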