remove old prompt table for storing cached ptuning representations #7295

Merged: 12 commits, Aug 25, 2023
@@ -81,7 +81,9 @@ def get_all_keys(self,):
Returns all the keys in the model
"""
k = [n for n, p in self.named_parameters()]
return set(k)
b = [n for n, p in self.named_buffers() if n in self.state_dict().keys()]
# we include buffers because ptuning representations are cached in a buffer and saved to state_dict for inference time use.
return set(k + b)
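The buffer-aware key gathering above relies on a standard PyTorch distinction: a buffer registered with `persistent=True` shows up in `named_buffers()` and `state_dict()`, but never in `named_parameters()`. A minimal sketch of that behaviour (the module and attribute names are illustrative, not NeMo's):

```python
import torch
import torch.nn as nn

class TinyAdapter(nn.Module):
    """Illustrative module: one trainable weight plus a persistent cache buffer."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim))
        # Persistent buffers are written to state_dict() but are not parameters.
        self.register_buffer("cache", torch.zeros(dim), persistent=True)

m = TinyAdapter()
print([n for n, _ in m.named_parameters()])  # ['weight']
print([n for n, _ in m.named_buffers()])     # ['cache']
print(sorted(m.state_dict().keys()))         # ['cache', 'weight']
# A key-tracking helper like get_all_keys() therefore has to union parameter
# and buffer names to cover every entry that will land in the saved state_dict.
```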

def get_peft_state_dict(self,):
"""
@@ -134,7 +136,6 @@ def setup_optimizer_param_groups(self):
module.set_enabled_adapters(enabled=True)
module.unfreeze_enabled_adapters() # selectively unfreeze the adapter modules.
opt_params += [p for p in module.parameters() if p.requires_grad]

self._optimizer_param_groups = ({"params": opt_params},)
logging.info(f"Optimizer groups set:\n{self.summarize()}")
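The `({"params": opt_params},)` tuple built here is the standard PyTorch parameter-group format and is consumed by the optimizer as-is. A rough sketch of the same freeze-then-collect pattern on a toy module (the optimizer choice and shapes are illustrative):

```python
import torch

model = torch.nn.Linear(8, 8)

# Freeze everything, then selectively re-enable what should train,
# mirroring the freeze/unfreeze pattern in setup_optimizer_param_groups().
for p in model.parameters():
    p.requires_grad = False
model.bias.requires_grad = True

opt_params = [p for p in model.parameters() if p.requires_grad]
optimizer_param_groups = ({"params": opt_params},)

# torch optimizers accept an iterable of group dicts directly.
optimizer = torch.optim.Adam(optimizer_param_groups, lr=1e-4)
print(sum(p.numel() for g in optimizer.param_groups for p in g["params"]))  # 8 (bias only)
```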

@@ -326,13 +327,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.name_key_to_cfg = {AdapterName.PTUNING_ADAPTER: adapter_cfg}
super().__init__(cfg, trainer)
self.virtual_tokens = cfg.peft.p_tuning.virtual_tokens
self.trainable_keys = self.adapter_keys - set(
[
"model.language_model.adapter_layer.ptuning_adapter.inference_table.prompt_table.taskname.prompt_embeddings.weight",
"model.module.language_model.adapter_layer.ptuning_adapter.inference_table.prompt_table.taskname.prompt_embeddings.weight", # for Float16Model models
]
)
# we exclude the above parameter from training because it is present for backward compatibility for inference using FasterTransformer (@adithyare)

def init_peft_modules(self,):
"""
@@ -376,15 +370,7 @@ def load_state_dict(self, state_dict, strict: bool = True):

def setup_optimizer_param_groups(self):
if self.first_stage_of_pipeline():
# super().setup_optimizer_param_groups()
self.freeze() # Freeze the entire model
opt_params = []
for n, p in self.named_parameters():
if n in self.trainable_keys:
p.requires_grad = True
opt_params.append(p)

self._optimizer_param_groups = ({"params": opt_params},)
super().setup_optimizer_param_groups()
else:
self.freeze() # Freeze the entire model
self._optimizer_param_groups = ({"params": []},)
@@ -430,24 +416,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer)
self.virtual_tokens = cfg.peft.p_tuning.virtual_tokens

def setup_optimizer_param_groups(self):
super().setup_optimizer_param_groups()

# (guyueh1) This part is used to avoid adding frozen parameters in trainable adapter modules
# in the setup_optimizer_param_groups() of the MegatronPEFTModel class, all parameters
# in an adapter module are going to be set requires_grad=True. However in ptuning
# adapter the inference table should be untrainable. We explicitely set that parameter
# to untrainable here.
self.trainable_keys = self.adapter_keys - set(
[
"model.language_model.adapter_layer.ptuning_adapter.inference_table.prompt_table.taskname.prompt_embeddings.weight",
"model.module.language_model.adapter_layer.ptuning_adapter.inference_table.prompt_table.taskname.prompt_embeddings.weight", # for Float16Model or BFloat16Model models
]
)
for n, p in self.named_parameters():
if not (n in self.trainable_keys):
p.requires_grad_(False)


class MegatronGPTLoRAModel(MegatronGPTLayerwisePEFTModel):
"""
(second changed file)
@@ -26,7 +26,6 @@
from nemo.collections.common.parts.utils import activation_registry
from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, init_method_const, init_method_normal
from nemo.collections.nlp.modules.common.prompt_encoder import InferenceTable
from nemo.core.classes.mixins import adapter_mixin_strategies

try:
@@ -322,7 +321,8 @@
# (@adithyare) the persistent=False will not pollute the indices into the state_dict of this module.
self.register_buffer("indices", torch.LongTensor(list(range(self.virtual_tokens))), persistent=False)
self.embedding = torch.nn.Embedding(self.virtual_tokens, self.embedding_dim)
self.inference_table = InferenceTable("taskname", self.output_dim, self.virtual_tokens)
self.register_buffer("inference_table", torch.Tensor(self.virtual_tokens, self.output_dim), persistent=True)
self.is_inference_ready = False
self.first = ColumnParallelLinear(
self.embedding_dim,
self.bottleneck_dim,
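Replacing the `InferenceTable` module with a plain persistent buffer plus a readiness flag is the heart of this change. A self-contained sketch of that caching pattern, pulling together the buffer registration above and the helper methods in the hunks below (the class name and dimensions are illustrative; the PR registers an uninitialized `torch.Tensor`, zeros are used here only for determinism):

```python
import torch
import torch.nn as nn

class CachedPromptTable(nn.Module):
    """Sketch of the buffer-based cache: persistent=True keeps it in state_dict,
    so an inference-time checkpoint still carries the cached representation."""

    def __init__(self, virtual_tokens: int = 10, output_dim: int = 16):
        super().__init__()
        self.register_buffer(
            "inference_table", torch.zeros(virtual_tokens, output_dim), persistent=True
        )
        self.is_inference_ready = False

    def set_inference_table(self, prompt_representation: torch.Tensor) -> bool:
        # Detach so the cached copy is not tied to the autograd graph.
        self.inference_table.data = prompt_representation.detach().clone()
        self.is_inference_ready = True
        return True

    def clear_inference_table(self):
        self.inference_table.fill_(0.0)
        self.is_inference_ready = False

    def get_inference_table(self) -> torch.Tensor:
        return self.inference_table.data
```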
@@ -356,13 +356,16 @@
This method caches the output representation from the Encoder and saves it inside `self.inference_table`.
"""
prompt_representation = prompt_representation.detach().clone()
self.inference_table.set_prompt_table(prompt_representation)
self.inference_table.data = prompt_representation
self.is_inference_ready = True
return True

def clear_inference_table(self,):
self.inference_table.clear_prompt_table()
self.inference_table.fill_(0.0)
self.is_inference_ready = False

def get_inference_table(self,):
return self.inference_table.get_prompt_table()
return self.inference_table.data

def inner_forward(self,):
input_embeds = self.embedding(self.indices).unsqueeze(0)
@@ -381,11 +384,12 @@
output_embeds = self.get_inference_table().unsqueeze(1)
else:
if self.training:
if self.inference_table.is_inference_ready:
if self.is_inference_ready:
self.clear_inference_table()
output_embeds = self.inner_forward()
else:
if not self.inference_table.is_inference_ready:
output_embeds = self.inner_forward()

[Check warning from Code scanning / CodeQL, "Variable defined multiple times": this assignment to 'output_embeds' is unnecessary as it is redefined before this value is used.]
if not self.is_inference_ready:
output_embeds = self.inner_forward()
self.set_inference_table(output_embeds.squeeze(1))
output_embeds = self.get_inference_table().unsqueeze(1)
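Because the flattened diff view drops indentation, the caching control flow in this forward pass is easier to read spelled out: during training any stale cache is cleared and the representation is recomputed each step; in eval mode it is computed once, stored in the persistent buffer, and then served from the cache. A hedged sketch of that tail of the forward method (the nesting is inferred from the hunk, and it assumes the buffer-and-flag helpers plus an `inner_forward()` as sketched earlier):

```python
def _cached_forward_tail(self):
    """Sketch of the training/eval caching branches shown in the hunk above."""
    if self.training:
        # A cache left over from a previous eval pass must not leak into training.
        if self.is_inference_ready:
            self.clear_inference_table()
        output_embeds = self.inner_forward()
    else:
        # First eval forward: compute once and cache it in the persistent buffer.
        if not self.is_inference_ready:
            output_embeds = self.inner_forward()
            self.set_inference_table(output_embeds.squeeze(1))
        # Every eval forward then reads from the cache.
        output_embeds = self.get_inference_table().unsqueeze(1)
    return output_embeds
```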