Trainable Tokens: Support for Weight Tying #2399
base: main
Changes from all commits
@@ -275,12 +275,12 @@ class LoraConfig(PeftConfig):
             The core module from Megatron to use, defaults to `"megatron.core"`.
         trainable_token_indices (`Optional[Union[List[int], dict[str, List[int]]]]`)
             Lets you specify which token indices to selectively fine-tune without requiring to re-train the whole
-            embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list of indices
-            which will then target the `embed_tokens` layer, or, if your model is using a different layer for
-            embedding, you can specify a dictionary where the key is the name of the embedding module and the values
-            are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. Note that training with FSDP/DeepSpeed
-            might not yet be fully supported with this option enabled. Also note that models using weight-tying are
-            currently not supported.
+            embedding matrix using the `peft.TrainableTokensModel` method. You can specify token indices in two ways.
+            Either you specify a list of indices which will then target the model's input embedding layer (or, if not
+            found, `embed_tokens`). Alternatively, you can specify a dictionary where the key is the name of the
+            embedding module and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. Note
+            that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. Also note that
+            models using weight-tying are currently not supported.

Review comment: Adjust/delete?

         loftq_config (`Optional[LoftQConfig]`):
             The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
             and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
@@ -444,10 +444,11 @@ class LoraConfig(PeftConfig):
         metadata={
             "help": (
                 "Lets you specify which token indices to selectively fine-tune without requiring to re-train the "
-                "whole embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list "
-                "of indices which will then target the `embed_tokens` layer, or, if your model is using a different "
-                "layer for embedding, you can specify a dictionary where the key is the name of the embedding module "
-                "and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. "
+                "whole embedding matrix using the `peft.TrainableTokensModel` method. You can specify token indices "
+                "in two ways. Either you specify a list of indices which will then target the model's input embedding "
+                "layer (or, if not found, `embed_tokens`). Alternatively, you can specify a dictionary where the key "
+                "is the name of the embedding module and the values are the list of token indices, e.g. "
+                "`{'embed_tokens': [0, 1, ...]}`. "
                 "Note that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. "
                 "Also note that models using weight-tying are currently not supported."
             )
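For readers skimming the diff, here is a minimal sketch (not part of this PR) of the two ways of passing `trainable_token_indices` that the updated docstring describes. The module name `embed_tokens` is just the common default and may differ per model.

```python
from peft import LoraConfig

# 1) Plain list: targets the model's input embedding layer (falling back to `embed_tokens`).
config_from_list = LoraConfig(trainable_token_indices=[0, 1, 2])

# 2) Dict: name the embedding module explicitly, mapping it to the indices to train.
config_from_dict = LoraConfig(trainable_token_indices={"embed_tokens": [0, 1, 2]})
```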
@@ -37,24 +37,40 @@ def __init__(
         base_layer: nn.Module,
         adapter_name: str,
         token_indices: list[int],
+        tied_adapter: Optional[TrainableTokensLayer] = None,
         **kwargs,
     ) -> None:
         super().__init__()

         self.base_layer = base_layer
         self._active_adapter = adapter_name
-        self.token_indices = {}
         self.kwargs = kwargs

+        self.tied_adapter = tied_adapter
+
         # we store the updated weights of particular tokens and their originals. we assume
         # that the count of new tokens is far smaller than the number of total tokens.
-        self.trainable_tokens_delta = nn.ParameterDict({})
-        self.trainable_tokens_original = BufferDict({})
+        #
+        # In case we have weight tying with another token adapter, we'll have no actual
+        # references on our own but use everything from the tied adapter.
+        if not self.tied_adapter:
+            self.trainable_tokens_delta = nn.ParameterDict({})
+            self.trainable_tokens_original = BufferDict({})
+            self.token_indices = {}
+        else:
+            self.trainable_tokens_delta = self.tied_adapter.trainable_tokens_delta

Review comment: IIRC, you mentioned that stuff like

+            self.trainable_tokens_original = self.tied_adapter.trainable_tokens_original
+            self.token_indices = self.tied_adapter.token_indices

         # Mark the weight as unmerged
         self.merged_adapters = []

     def update_layer(self, adapter_name, **kwargs):
+        if kwargs.get("tied_adapter", None):
+            # in this case we don't have any say, we're just following whatever the tied

Review comment (suggested change): ?

+            # adapter does, so we'll just return here.
+            return
+
         self.token_indices[adapter_name] = kwargs["token_indices"]
         init_weights = kwargs.get("init_weights", True)
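To illustrate the sharing set up in `__init__` above: a tied layer reuses the container objects of the adapter it is tied to, so both layers train the very same parameters. A toy sketch with hypothetical class names, not the PR's code:

```python
import torch
import torch.nn as nn

class TinyTiedDeltas(nn.Module):
    """Hypothetical stand-in for TrainableTokensLayer's delta bookkeeping."""

    def __init__(self, tied_to=None):
        super().__init__()
        if tied_to is None:
            self.trainable_tokens_delta = nn.ParameterDict({})
        else:
            # Reuse the tied layer's container instead of creating our own.
            self.trainable_tokens_delta = tied_to.trainable_tokens_delta

embedding_side = TinyTiedDeltas()
embedding_side.trainable_tokens_delta["default"] = nn.Parameter(torch.zeros(3, 8))

lm_head_side = TinyTiedDeltas(tied_to=embedding_side)

# Same ParameterDict object, hence the same trainable parameter on both sides.
assert lm_head_side.trainable_tokens_delta["default"] is embedding_side.trainable_tokens_delta["default"]
```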
@@ -130,6 +146,16 @@ def unmerge(self) -> None:
             originals = self.trainable_tokens_original[adapter_name].to(self.base_layer.weight)
             self.base_layer.weight.data.index_copy_(dim=0, index=index, source=originals)

+    def get_merged_weights(self, active_adapters):
+        W = self.base_layer.weight
+
+        for adapter_name in active_adapters:
+            index = torch.tensor(self.token_indices[adapter_name]).to(W.device)
+            deltas = self.trainable_tokens_delta[adapter_name].to(W)
+            W = W.index_copy(dim=0, index=index, source=deltas)
+
+        return W
+
     def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) -> torch.Tensor:
         if self.disable_adapters or not active_adapters:
             if self.merged:
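As a quick illustration of what the new `get_merged_weights` does (toy tensors, not the PR's code): the out-of-place `index_copy` swaps in the learned delta rows at the trainable token indices while leaving the base weight untouched.

```python
import torch

W = torch.zeros(5, 4)              # stand-in embedding matrix: vocab_size=5, hidden=4
index = torch.tensor([1, 3])       # token indices selected for fine-tuning
deltas = torch.ones(2, 4)          # learned replacement rows for tokens 1 and 3

merged = W.index_copy(dim=0, index=index, source=deltas)

assert merged[1].sum() == 4        # row 1 replaced in the merged copy
assert merged[3].sum() == 4        # row 3 replaced as well
assert W.sum() == 0                # original weight untouched (index_copy is out-of-place)
```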
@@ -140,22 +166,31 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) -> torch.Tensor:
         else:
             self._check_overlapping_tokens(active_adapters)

-            W = self.base_layer.weight
-
-            for adapter_name in active_adapters:
-                index = torch.tensor(self.token_indices[adapter_name]).to(W.device)
-                deltas = self.trainable_tokens_delta[adapter_name].to(W)
-                W = W.index_copy(dim=0, index=index, source=deltas)
-
-            result = F.embedding(
-                input=x,
-                weight=W,
-                padding_idx=self.base_layer.padding_idx,
-                max_norm=self.base_layer.max_norm,
-                norm_type=self.base_layer.norm_type,
-                scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
-                sparse=self.base_layer.sparse,
-            )
+            W = self.get_merged_weights(active_adapters)
+
+            # Normally it should be very clear that we're wrapping Embedding layers but there are cases, such as
+            # tying weights with an LM head where the layer we wrap is a Linear layer. Therefore we must choose
+            # accordingly.
+            if isinstance(self.base_layer, torch.nn.Embedding):
+                result = F.embedding(
+                    input=x,
+                    weight=W,
+                    padding_idx=self.base_layer.padding_idx,
+                    max_norm=self.base_layer.max_norm,
+                    norm_type=self.base_layer.norm_type,
+                    scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
+                    sparse=self.base_layer.sparse,
+                )
+            elif isinstance(self.base_layer, torch.nn.Linear):

Review comment: This would not necessarily work with quantized models, right? I wonder if we can find a more robust way of handling this, but I'm not sure how exactly.

+                # Probably a tied adapter that wraps an LM head.
+                result = F.linear(
+                    input=x,
+                    weight=W,
+                )
+            else:
+                raise ValueError(
+                    "TrainableTokensLayer wraps an unknown layer type, maybe you are targeting the wrong layer?"
+                )

         return result
Review comment: How about a comment why this check is required?
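For context on the Embedding/Linear branch above: with weight tying, the LM head is a Linear layer whose weight is the same matrix as the input embedding, so the merged matrix W is consumed in two different ways. A small sketch with toy shapes, not the PR's code:

```python
import torch
import torch.nn.functional as F

vocab_size, hidden = 10, 4
W = torch.randn(vocab_size, hidden)          # tied weight shared by embedding and LM head

token_ids = torch.tensor([[1, 2, 3]])        # (batch, seq)
hidden_states = torch.randn(1, 3, hidden)    # (batch, seq, hidden)

embedded = F.embedding(input=token_ids, weight=W)   # row lookup  -> (1, 3, hidden)
logits = F.linear(input=hidden_states, weight=W)    # projection  -> (1, 3, vocab_size)

print(embedded.shape, logits.shape)
```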