-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CPT Tuner #2168
CPT Tuner #2168
Changes from 14 commits
92b9e1a
e54d380
023f071
54cddaf
2dfe70f
ba4b115
f8c8317
bd2fc70
b01b214
6ed1723
77bb0b9
dbcdedf
f7138d4
0a5fb20
7206db5
24b0af9
81ffa09
130ec76
70067d8
9397314
249713c
0a43473
144f042
97449da
dacb400
cc348a4
79959d1
7eea892
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
|
||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be | ||
rendered properly in your Markdown viewer. | ||
--> | ||
|
||
# Context-Aware Prompt Tuning (CPT) | ||
|
||
[Context-aware Prompt Tuning: Advancing In-Context Learning with Adversarial Methods (CPT)](https://huggingface.co/papers/2410.17222) combines In-Context Learning (ICL) with Prompt Tuning (PT) and adversarial optimization to improve few-shot learning by refining context embeddings. CPT optimizes only context tokens, which minimizes overfitting and enhances performance on classification tasks. | ||
|
||
The abstract from the paper is: | ||
|
||
*Traditional fine-tuning is effective but computationally intensive, as it requires updating billions of parameters. CPT, inspired by ICL, PT, and adversarial attacks, refines context embeddings in a parameter-efficient manner. By optimizing context tokens and applying a controlled gradient descent, CPT achieves superior accuracy across various few-shot classification tasks, showing significant improvement over existing methods such as LoRA, PT, and ICL.* | ||
|
||
## CPTConfig | ||
|
||
[[autodoc]] tuners.cpt.config.CPTConfig | ||
|
||
## CPTModel | ||
|
||
[[autodoc]] tuners.cpt.model.CPTModel | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright 2024-present the HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from .config import CPTConfig | ||
from .model import CPTEmbedding | ||
|
||
|
||
__all__ = ["CPTConfig", "CPTEmbedding"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Copyright 2024-present the HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import enum | ||
BenjaminBossan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from dataclasses import dataclass, field | ||
from typing import Optional | ||
|
||
from peft.config import PromptLearningConfig | ||
from peft.utils import PeftType | ||
|
||
|
||
|
||
class CPTPromptInit(str, enum.Enum): | ||
"""Enum for specifying the initialization method for CPT.""" | ||
|
||
TEXT = "TEXT" # Initialize using text-based embeddings. | ||
RANDOM = "RANDOM" # Initialize randomly. | ||
|
||
|
||
@dataclass | ||
class CPTConfig(PromptLearningConfig): | ||
""" | ||
CPT Configuration class extending PeftConfig for Context-aware Prompt Tuning (CPT). | ||
|
||
This class introduces additional parameters required for CPT, such as: | ||
- Token type masks | ||
- Prompt tuning initialization | ||
- Loss weighting | ||
- Projection settings | ||
|
||
For more details, see the paper: https://arxiv.org/abs/2410.17222 | ||
""" | ||
|
||
# Token-related configurations | ||
cpt_token_ids: Optional[list[int]] = field( | ||
default=None, metadata={"help": "Tensor of token IDs used for CPT prompts."} | ||
) | ||
cpt_mask: Optional[list[int]] = field(default=None, metadata={"help": "Tensor mask applied to CPT tokens."}) | ||
cpt_tokens_type_mask: Optional[list[int]] = field( | ||
default=None, metadata={"help": "Mask indicating the type of each CPT token."} | ||
) | ||
|
||
# Prompt tuning initialization method | ||
cpt_prompt_init: Optional[str] = field( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It already exists in the code.
I changed it. |
||
default="TEXT", metadata={"help": "Initialization method: 'TEXT' for embedding-based, 'RANDOM' for random."} | ||
) | ||
|
||
# Loss-related configurations | ||
opt_weighted_loss_type: Optional[str] = field( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could change the type to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still relevant There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My suggestion is to change the type annotation to |
||
default="none", metadata={"help": "Type of weighted loss: 'none' or 'decay'."} | ||
) | ||
opt_loss_decay_factor: Optional[float] = field( | ||
default=1.0, metadata={"help": "Factor for exponential decay in loss weighting."} | ||
) | ||
|
||
# Projection-related configurations | ||
opt_projection_epsilon: Optional[float] = field( | ||
default=0.1, metadata={"help": "Epsilon value for input projection."} | ||
) | ||
opt_projection_format_epsilon: Optional[float] = field( | ||
default=0.1, metadata={"help": "Epsilon value for format projection."} | ||
) | ||
|
||
# Tokenizer configuration | ||
tokenizer_name_or_path: Optional[str] = field( | ||
default=None, | ||
metadata={ | ||
"help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" | ||
}, | ||
) | ||
|
||
# Virtual token configurations | ||
num_virtual_tokens: int = field(default=0, metadata={"help": "Number of virtual tokens used in the prompt."}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is 0 a sensible default for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think having 0 as the default here makes little sense. WDYT about using a good default here, say, 10? |
||
|
||
# CPT-specific static attributes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are not supposed to be modified by the user, right? In that case, let's move them inside of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. WDYT about this suggestion? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it ever make sense to let users pass these arguments? If not, I would remove them here and place them inside the |
||
is_prompt_learning = True # Indicates that CPT is a prompt-learning method. | ||
num_layers = None # Number of layers (optional, not always required). | ||
token_dim = None # Dimension of token embeddings. | ||
num_attention_heads = None # Number of attention heads (if applicable). | ||
task_type = "CAUSAL_LM" # Specifies that CPT is used for causal language modeling. | ||
num_transformer_submodules = 1 # Number of transformer submodules used. | ||
|
||
def __post_init__(self): | ||
""" | ||
Post-initialization hook to set additional attributes after the config is initialized. | ||
""" | ||
self.peft_type = PeftType.CPT # Specifies that the PEFT type is CPT. | ||
self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling. | ||
BenjaminBossan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_token_ids is None: | ||
raise ValueError( | ||
f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " | ||
f"cpt_token_ids can't be None." | ||
) | ||
if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_mask is None: | ||
raise ValueError( | ||
f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " | ||
f"cpt_mask can't be None." | ||
) | ||
if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_tokens_type_mask is None: | ||
raise ValueError( | ||
f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " | ||
f"cpt_tokens_type_mask can't be None." | ||
) | ||
if (self.cpt_prompt_init == CPTPromptInit.TEXT) and not (len(self.cpt_token_ids) == len(self.cpt_mask) == len(self.cpt_tokens_type_mask) == self.num_virtual_tokens): | ||
raise ValueError( | ||
f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " | ||
f"cpt_token_ids, cpt_mask and cpt_tokens_type_mask must have the same length." | ||
) | ||
|
||
if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_token_ids is not None: | ||
raise ValueError( | ||
f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " | ||
f"cpt_token_ids must be None." | ||
) | ||
if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_mask is not None: | ||
raise ValueError( | ||
f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " | ||
f"cpt_mask must be None." | ||
) | ||
if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_tokens_type_mask is not None: | ||
raise ValueError( | ||
f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " | ||
f"cpt_tokens_type_mask must be None." | ||
) | ||
if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.num_virtual_tokens == 0: | ||
raise ValueError( | ||
f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " | ||
f"num_virtual_tokens must be greater than zero." | ||
) | ||
if (self.cpt_prompt_init != CPTPromptInit.RANDOM) and (self.cpt_prompt_init != CPTPromptInit.TEXT): | ||
raise ValueError( | ||
f"prompt_tuning_init must be 'RANDOM' or 'TEXT'" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no
CPTModel
, onlyCPTEmbedding
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can I add a link to an example in the cpt.md file? I didn't see any other methods linked to examples. If you have an example, it would be a great help.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I see no reason not to add a link there. Just be aware that the link would be invalid until the PR is merged.