remove old prompt table for storing cached ptuning representations #7295

Merged · 12 commits · Aug 25, 2023
@@ -329,11 +329,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.virtual_tokens = cfg.peft.p_tuning.virtual_tokens
self.trainable_keys = self.adapter_keys - set(
[
"model.language_model.adapter_layer.ptuning_adapter.inference_table.prompt_table.taskname.prompt_embeddings.weight",
"model.module.language_model.adapter_layer.ptuning_adapter.inference_table.prompt_table.taskname.prompt_embeddings.weight", # for Float16Model models
"model.language_model.adapter_layer.ptuning_adapter.inference_table.weight"
"model.module.language_model.adapter_layer.ptuning_adapter.inference_table.weight" # for Float16Model models
]
)
# we exclude the above parameter from training because it exists only for backward compatibility with FasterTransformer inference (@adithyare)

def init_peft_modules(self,):
"""
@@ -430,24 +429,23 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
}
super().__init__(cfg, trainer)
self.virtual_tokens = cfg.peft.p_tuning.virtual_tokens

def setup_optimizer_param_groups(self):
super().setup_optimizer_param_groups()

# (guyueh1) This block keeps frozen parameters out of the trainable adapter modules:
# in setup_optimizer_param_groups() of the MegatronPEFTModel class, all parameters
# in an adapter module are set to requires_grad=True, but in the p-tuning
# adapter the inference table must stay untrainable. We explicitly mark that
# parameter as untrainable here.
self.trainable_keys = self.adapter_keys - set(
[
"model.language_model.adapter_layer.ptuning_adapter.inference_table.prompt_table.taskname.prompt_embeddings.weight",
"model.module.language_model.adapter_layer.ptuning_adapter.inference_table.prompt_table.taskname.prompt_embeddings.weight", # for Float16Model or BFloat16Model models
"model.language_model.adapter_layer.ptuning_adapter.inference_table.weight"
guyueh1 marked this conversation as resolved.
Show resolved Hide resolved
"model.module.language_model.adapter_layer.ptuning_adapter.inference_table.weight" # for Float16Model models
]
)

def setup_optimizer_param_groups(self):
self.freeze() # Freeze the entire model
opt_params = []
for n, p in self.named_parameters():
if not (n in self.trainable_keys):
p.requires_grad_(False)
if n in self.trainable_keys:
p.requires_grad = True
opt_params.append(p)

self._optimizer_param_groups = ({"params": opt_params},)
logging.info(f"Optimizer groups set:\n{self.summarize()}")


class MegatronGPTLoRAModel(MegatronGPTLayerwisePEFTModel):
Expand Down
@@ -26,7 +26,6 @@
from nemo.collections.common.parts.utils import activation_registry
from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, init_method_const, init_method_normal
from nemo.collections.nlp.modules.common.prompt_encoder import InferenceTable
from nemo.core.classes.mixins import adapter_mixin_strategies

try:
@@ -322,7 +321,9 @@ def __init__(
# (@adithyare) the persistent=False will not pollute the indices into the state_dict of this module.
self.register_buffer("indices", torch.LongTensor(list(range(self.virtual_tokens))), persistent=False)
self.embedding = torch.nn.Embedding(self.virtual_tokens, self.embedding_dim)
self.inference_table = InferenceTable("taskname", self.output_dim, self.virtual_tokens)
self.inference_table = nn.Embedding(self.virtual_tokens, self.output_dim)
self.inference_table.weight.requires_grad = False
self.is_inference_ready = False
self.first = ColumnParallelLinear(
self.embedding_dim,
self.bottleneck_dim,
@@ -356,13 +357,16 @@ def set_inference_table(self, prompt_representation: torch.Tensor):
This method caches the output representation from the Encoder and saves it inside `self.inference_table`.
"""
prompt_representation = prompt_representation.detach().clone()
self.inference_table.set_prompt_table(prompt_representation)
self.inference_table.weight.data = prompt_representation
self.is_inference_ready = True
return True

def clear_inference_table(self,):
self.inference_table.clear_prompt_table()
self.is_inference_ready = False
self.inference_table.weight.data.fill_(0.0)

def get_inference_table(self,):
return self.inference_table.get_prompt_table()
return self.inference_table.weight.data

def inner_forward(self,):
input_embeds = self.embedding(self.indices).unsqueeze(0)
@@ -381,11 +385,11 @@ def forward(self, batch_size: int, use_cached_reps: bool = False) -> torch.Tensor:
output_embeds = self.get_inference_table().unsqueeze(1)
else:
if self.training:
if self.inference_table.is_inference_ready:
if self.is_inference_ready:
self.clear_inference_table()
output_embeds = self.inner_forward()
else:
if not self.inference_table.is_inference_ready:
if not self.is_inference_ready:
output_embeds = self.inner_forward()
self.set_inference_table(output_embeds.squeeze(1))
output_embeds = self.get_inference_table().unsqueeze(1)
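
The forward logic above amounts to a cache-on-eval / invalidate-on-train scheme: the first evaluation pass runs the encoder once, detaches the result into the frozen `inference_table`, and subsequent eval batches read the cached weights back; any training step clears the cache so a stale representation is never served. A minimal sketch of that scheme follows, using a plain `nn.Sequential` encoder as a stand-in for the adapter's ColumnParallelLinear stack (names and the output layout are illustrative, not the NeMo implementation).

```python
import torch
from torch import nn

class CachedPromptEncoder(nn.Module):
    """Sketch of caching encoder output in a frozen embedding during eval."""

    def __init__(self, virtual_tokens: int = 8, dim: int = 32):
        super().__init__()
        self.register_buffer("indices", torch.arange(virtual_tokens), persistent=False)
        self.embedding = nn.Embedding(virtual_tokens, dim)
        self.encoder = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.inference_table = nn.Embedding(virtual_tokens, dim)  # frozen cache
        self.inference_table.weight.requires_grad_(False)
        self.is_inference_ready = False

    def inner_forward(self) -> torch.Tensor:
        # Recompute the virtual-token representations: [virtual_tokens, dim]
        return self.encoder(self.embedding(self.indices))

    def forward(self, batch_size: int) -> torch.Tensor:
        if self.training:
            # Training invalidates the cache so stale representations are never reused.
            if self.is_inference_ready:
                self.inference_table.weight.data.fill_(0.0)
                self.is_inference_ready = False
            reps = self.inner_forward()
        else:
            # First eval call fills the cache; later calls skip the encoder entirely.
            if not self.is_inference_ready:
                with torch.no_grad():
                    self.inference_table.weight.data = self.inner_forward().detach().clone()
                self.is_inference_ready = True
            reps = self.inference_table.weight.data
        # Broadcast along a batch dimension (illustrative layout: [virtual_tokens, batch, dim]).
        return reps.unsqueeze(1).expand(-1, batch_size, -1)
```

Calling `model.train()` and running a forward pass drops any cached table, while repeated `model.eval()` passes reuse the representation computed on the first call.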