
LoRA Builders for MM #1661

Merged: 7 commits into pytorch:main on Sep 24, 2024
Conversation

pbontrager (Contributor):

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

  • CLIP LoRA encoder + LoRA MLP + LoRA attention (rough usage sketch below)
  • Flamingo LoRA encoder + decoder + projection head
  • Llama 3.1: updated shared LoRA util
  • TODO: update LoRA recipes to match the full finetune updates
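
As a rough sketch of what calling the new CLIP LoRA builder might look like (only lora_modules, apply_lora_to_mlp, and apply_lora_to_output appear in this PR's diff; the import path, the lora_modules values, and the remaining hyperparameters are illustrative assumptions):

from torchtune.models.clip import lora_clip_vision_encoder  # import path assumed

# Illustrative only: adapt the attention projections and MLPs of the CLIP
# vision encoder with LoRA. The remaining required encoder/LoRA
# hyperparameters (tile/patch sizes, embed dim, rank, alpha, ...) are elided.
encoder = lora_clip_vision_encoder(
    lora_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    # ... remaining encoder- and LoRA-specific arguments ...
)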

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any of these, just ask and we will happily help. We also have a contributing page with some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot (bot) commented Sep 24, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1661

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 72b0139 with merge base 34d70b4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 24, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@@ -197,21 +527,34 @@ def flamingo_decoder(
    for idx in range(1, num_layers + 1):

        # Self attention layers for text decoder
-       rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
+       rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
        self_attn = MultiHeadAttention(
Contributor:

rope should be instantiated only once, outside of the for loop, to avoid copies and extra memory.
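
A sketch of the suggested change (the loop body is elided; only the placement of the RoPE construction differs):

# RoPE does not depend on the layer index, so build it once and share the
# same instance across every decoder layer instead of constructing a copy
# per iteration.
rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)

for idx in range(1, num_layers + 1):
    # Self attention layers for the text decoder reuse the shared `rope` instance.
    ...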

@@ -282,10 +284,12 @@ def setup(self, cfg: DictConfig) -> None:

# Dataloader depends on the tokenizer and loss_fn and should be
# setup after all of these are setup
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
Contributor:

Why is padded_collate_sft the default? I thought that didn't work

pbontrager (author):

That's the standard collate for finetuning text models.
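
For context, a minimal sketch of how a dotted default like "torchtune.data.padded_collate_sft" can be resolved to a callable at setup time (illustrative only; the recipe may use torchtune's own config utilities rather than importlib):

import importlib

def resolve_collate_fn(cfg, default="torchtune.data.padded_collate_sft"):
    """Look up collate_fn in the config (falling back to the SFT default) and import it."""
    collate_name = cfg.get("collate_fn", default)
    module_name, attr_name = collate_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr_name)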

torchtune/models/clip/_component_builders.py (resolved comment thread)
@@ -545,15 +555,20 @@ def _setup_data(

        if isinstance(cfg_dataset, ListConfig):
            datasets = [
-               config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
+               config.instantiate(single_cfg_dataset, self._tokenizer)
@ebsmothers (Contributor) commented Sep 24, 2024:

This doesn't need to be keyword? Edit: oh I guess not for the builder versions, just if you're actually passing SFTDataset directly (which I guess we won't support?)

pbontrager (author):

Some datasets call it tokenizer and others call it transforms
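
A toy illustration of the point (the builder names below are made up): because config.instantiate forwards positional arguments, the transform lands in the first parameter regardless of what each builder names it.

from typing import Any

# Hypothetical dataset builders with differently named first parameters.
def text_sft_dataset(tokenizer: Any, *, source: str = "some/text-dataset"):
    ...

def multimodal_sft_dataset(model_transform: Any, *, source: str = "some/mm-dataset"):
    ...

# config.instantiate(cfg, self._tokenizer) passes the tokenizer positionally,
# so it binds to `tokenizer` in one builder and `model_transform` in the other
# without the recipe having to know which name each builder chose.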

Comment on lines +171 to +178
activation: Callable = nn.SiLU,
cls_output_dim: int = 512,
attn_bias: bool = True,
out_indices: Optional[List[int]] = None,
output_cls_projection: bool = False,
max_num_tiles: int = 4,
in_channels: int = 3,
intermediate_act: torch.nn.Module = torch.nn.SiLU(),
Contributor:

Looking at L171 and L178 makes me sad

def lora_clip_vision_encoder(
    lora_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
Contributor:

Do we just need to pass this for consistency even though it's a no-op? Would maybe raise an error if it's set to true or something

Contributor:

I like consistency; agree with raising an error saying that it is a no-op/not supported. We do it for all tied-embedding models that don't support apply_lora_to_output.

pbontrager (author):

Why can't they be true?
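
For reference, a minimal sketch of the guard the reviewers are proposing (whether the flags genuinely can't be supported is left open by the thread; the error message is illustrative):

if apply_lora_to_output:
    raise ValueError(
        "apply_lora_to_output is not supported for lora_clip_vision_encoder; "
        "set it to False (or omit it) in the config."
    )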

        output: nn.Module,
        num_hidden_inputs: int = 0,
    ) -> None:
        super().__init__()
-       self.layers = _get_clones(layer, num_layers)
+       self.layers = nn.ModuleList(layers)
Contributor:

Shouldn't layers be List[nn.Module] type then?

Contributor:

Please look at the transformer, which supports all cases (List[nn.Module], nn.ModuleList, nn.Module):

if isinstance(layers, nn.ModuleList):

I am fine if in flamingo it only supports nn.ModuleList though, but I prefer the consistency.
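
For reference, a paraphrase of the pattern being pointed to in the transformer (not the exact torchtune code; the error messages are illustrative):

# Accept nn.ModuleList, a plain List[nn.Module], or a single nn.Module to clone.
if isinstance(layers, nn.ModuleList):
    self.layers = layers
elif isinstance(layers, list):
    self.layers = nn.ModuleList(layers)
elif isinstance(layers, nn.Module):
    if num_layers is None:
        raise ValueError("num_layers must be set when passing a single layer to clone")
    self.layers = _get_clones(layers, num_layers)
else:
    raise ValueError("layers must be an nn.Module, List[nn.Module], or nn.ModuleList")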

# ------------------ LoRA Flamingo ------------------


class LoRATrainable(Enum):
Contributor:

Where is this used?

"embed_dim": clip_embed_dim,
"num_layers": clip_num_layers,
"num_heads": num_heads,
"activation": nn.GELU,
@ebsmothers (Contributor) commented Sep 24, 2024:

Maybe it's deliberate, but it seems like this doesn't match the default in the CLIP builder; could be a potential source of confusion.

pbontrager (author):

Deliberate


def lora_flamingo_decoder(
    decoder_lora: bool,
    fusion_lora: bool,
Contributor:

not used?

pbontrager (author):

will be used in builders

Comment on lines +553 to +554
sa_norm=RMSNorm(dim=embed_dim, eps=1e-5),
mlp_norm=RMSNorm(dim=embed_dim, eps=1e-5),
Contributor:

super nit: use the same format for eps (below it's 1e-05)

Comment on lines +643 to +644
apply_lora_to_mlp: bool,
apply_lora_to_output: bool,
Contributor:

Any particular reason these are now keyword-only args as opposed to how we have them elsewhere?

Comment on lines +559 to +560
if idx % fusion_interval == 0:
    attn = lora_llama3_attention(
Contributor:

Could maybe use partials to reduce the duplicative code here? But nbd either way
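
A sketch of the partial-based variant the reviewer has in mind (the bound argument names are placeholders for whatever the two call sites actually share):

from functools import partial

# Bind the kwargs that both branches pass identically...
attn_builder = partial(
    lora_llama3_attention,
    lora_modules=lora_modules,   # placeholder shared kwargs
    embed_dim=embed_dim,
    num_heads=num_heads,
)

# ...then each branch only spells out what differs between fusion and
# regular self-attention layers.
if idx % fusion_interval == 0:
    attn = attn_builder()  # plus fusion-layer-specific kwargs
else:
    attn = attn_builder()  # plus self-attention-specific kwargs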

Comment on lines +593 to +594
ca_norm=RMSNorm(dim=embed_dim),
mlp_norm=RMSNorm(dim=embed_dim),
Contributor:

Is it deliberate to have different eps here vs in self-attention layers?

@pbontrager merged commit 18efc81 into pytorch:main on Sep 24, 2024.
17 checks passed
@pbontrager deleted the mm_lora branch on September 24, 2024 at 22:16.
Labels: CLA Signed · 5 participants