LoRA Builders for MM #1661
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1661
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 72b0139 with merge base 34d70b4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -197,21 +527,34 @@ def flamingo_decoder(
     for idx in range(1, num_layers + 1):

         # Self attention layers for text decoder
-        rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
+        rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
         self_attn = MultiHeadAttention(
rope should be instantiated only once, outside of the for loop, to avoid copies and extra memory.
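For illustration, a minimal sketch of the suggestion (reusing the names from the diff above; the MultiHeadAttention arguments are abbreviated and this is not the exact builder code):

```python
# Sketch: build RoPE once, outside the layer loop, and share the module across
# all self-attention layers instead of constructing a copy per layer.
rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)

layers = []
for idx in range(1, num_layers + 1):
    # Self attention layers for text decoder
    self_attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=rope,  # shared instance, no per-layer copy
        max_seq_len=max_seq_len,
    )
    # ... rest of the layer construction unchanged
```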
@@ -282,10 +284,12 @@ def setup(self, cfg: DictConfig) -> None:

        # Dataloader depends on the tokenizer and loss_fn and should be
        # setup after all of these are setup
        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
Why is padded_collate_sft the default? I thought that didn't work
That's the standard collate for finetuning text models.
@@ -545,15 +555,20 @@ def _setup_data(

        if isinstance(cfg_dataset, ListConfig):
            datasets = [
-               config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
+               config.instantiate(single_cfg_dataset, self._tokenizer)
This doesn't need to be keyword? Edit: oh I guess not for the builder versions, just if you're actually passing SFTDataset directly (which I guess we won't support?)
Some datasets call it tokenizer and others call it transforms
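For illustration, a minimal sketch of why the positional form covers both naming conventions (the dataset builders below are hypothetical, and this assumes config.instantiate forwards extra positional arguments to the underlying builder):

```python
# Hypothetical builders: one names the first parameter `tokenizer`,
# the other `model_transform`, but both take it as the first positional argument.
def my_text_dataset(tokenizer, *, source: str): ...
def my_multimodal_dataset(model_transform, *, source: str): ...

# Passing the transform positionally works for either builder, whereas
# tokenizer=... would only work for the first one.
ds = config.instantiate(single_cfg_dataset, self._tokenizer)
```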
    activation: Callable = nn.SiLU,
    cls_output_dim: int = 512,
    attn_bias: bool = True,
    out_indices: Optional[List[int]] = None,
    output_cls_projection: bool = False,
    max_num_tiles: int = 4,
    in_channels: int = 3,
    intermediate_act: torch.nn.Module = torch.nn.SiLU(),
Looking at L171 and L178 makes me sad
def lora_clip_vision_encoder(
    lora_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
Do we just need to pass this for consistency even though it's a no-op? Would maybe raise an error if it's set to true or something
I like consistency; agree with raising an error saying that it is a no-op/not supported. We do it for all tied-embedding models that don't support apply_lora_to_output.
Why can't they be true?
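For illustration, a minimal sketch of the suggested guard (the error message is illustrative, and it assumes apply_lora_to_output really is a no-op for this builder, per the discussion above):

```python
def lora_clip_vision_encoder(
    lora_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    # ... remaining parameters elided
):
    # Keep the flag for API consistency with the other LoRA builders, but fail
    # loudly instead of silently ignoring it.
    if apply_lora_to_output:
        raise ValueError(
            "apply_lora_to_output is not supported for the CLIP vision encoder; "
            "please set it to False."
        )
    ...
```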
        output: nn.Module,
        num_hidden_inputs: int = 0,
    ) -> None:
        super().__init__()
-       self.layers = _get_clones(layer, num_layers)
+       self.layers = nn.ModuleList(layers)
Shouldn't layers be List[nn.Module] type then?
Please look at the transformer, which supports all cases (List[nn.Module], nn.ModuleList, nn.Module):
torchtune/torchtune/modules/transformer.py, line 353 in 30b8519:
if isinstance(layers, nn.ModuleList):
I am fine if flamingo only supports nn.ModuleList, though I prefer the consistency.
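For illustration, a hypothetical helper mirroring that TransformerDecoder handling (this is not the actual transformer.py code, just a sketch of the three accepted forms):

```python
import copy
from typing import List, Optional, Union

from torch import nn


def _stack_layers(
    layers: Union[nn.Module, List[nn.Module], nn.ModuleList],
    num_layers: Optional[int] = None,
) -> nn.ModuleList:
    # Already a ModuleList: use it as-is.
    if isinstance(layers, nn.ModuleList):
        return layers
    # A plain Python list of modules: wrap it.
    if isinstance(layers, list):
        return nn.ModuleList(layers)
    # A single layer module: clone it num_layers times.
    if num_layers is None:
        raise ValueError("num_layers must be set when a single layer is passed")
    return nn.ModuleList([copy.deepcopy(layers) for _ in range(num_layers)])
```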
# ------------------ LoRA Flamingo ------------------


class LoRATrainable(Enum):
Where is this used?
"embed_dim": clip_embed_dim, | ||
"num_layers": clip_num_layers, | ||
"num_heads": num_heads, | ||
"activation": nn.GELU, |
Maybe it's deliberate but seems like this doesn't match the default in the CLIP builder, could be a potential source of confusion
Deliberate
def lora_flamingo_decoder(
    decoder_lora: bool,
    fusion_lora: bool,
not used?
will be used in builders
    sa_norm=RMSNorm(dim=embed_dim, eps=1e-5),
    mlp_norm=RMSNorm(dim=embed_dim, eps=1e-5),
super nit: use the same format for eps (below it's 1e-05)
    apply_lora_to_mlp: bool,
    apply_lora_to_output: bool,
Any particular reason these are now keyword-only args as opposed to how we have them elsewhere?
        if idx % fusion_interval == 0:
            attn = lora_llama3_attention(
Could maybe use partials to reduce the duplicative code here? But nbd either way
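For illustration, a sketch of the partial idea (the argument list for lora_llama3_attention is assumed here, and which arguments actually differ between the fusion and self-attention branches may not match the real builder):

```python
from functools import partial

# Bind the arguments shared by every layer once, then only vary what differs
# between the fusion (cross-attention) and regular self-attention branches.
build_attn = partial(
    lora_llama3_attention,
    lora_modules=lora_modules,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    max_seq_len=max_seq_len,
    lora_rank=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
)

for idx in range(1, num_layers + 1):
    if idx % fusion_interval == 0:
        attn = build_attn(pos_embeddings=None)  # fusion / cross-attention layer
    else:
        attn = build_attn(pos_embeddings=rope)  # standard self-attention layer
```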
    ca_norm=RMSNorm(dim=embed_dim),
    mlp_norm=RMSNorm(dim=embed_dim),
Is it deliberate to have different eps here vs in self-attention layers?
Context
What is the purpose of this PR? Please link to any issues this PR addresses.
Changelog
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.