Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Locking of Text Tower for CLIP Models #523

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

rsomani95
Copy link
Contributor

Currently, the lock-text.* args only work when using a HFTextEncoder. This PR adds a lock_text_tower method to the 'native' CLIP class as well.

There's two aspects that are different from the HFTextEncoder that I'd love feedback on:

  1. embeddings in this case are two separate layers, whereas HF has just one
  2. self.text_projection and self.ln_final are passed in as the last two layers respectively. The HF class doesn't have an analogue. So, to unfreeze the last transformer block, you need to pass in lock-text-unlocked-layers 3, which feels awkward.

@rwightman
Copy link
Collaborator

rwightman commented May 15, 2023

@rsomani95 thanks for the contribution

For this to work in all current cases, the bulk of the logic needs to be in a free function because this should work with both CLIP and CustomTextCLIP (w/ builtin text-tower) via the TextTransformer module, So the TextTransformer needs a lock method that calls the free fn, and the CLIP needs the lock_text_tower method you have that also calls the free fn, both passing in the list of modules/parameters...

Looking at the vision lock as a template, the grouping strategy used there is a solultion to your awkwardness and we should be consistent with that, we grouped the input + abs pos embed + any pre-ln into the first group, each transformer block as another group, lumped the last transformer block with the final LN, and final projection as its own...

        if unlocked_groups != 0:
            groups = [
                [
                    self.conv1,
                    self.class_embedding,
                    self.positional_embedding,
                    self.ln_pre,
                ],
                *self.transformer.resblocks[:-1],
                [
                    self.transformer.resblocks[-1],
                    self.ln_post,
                ],
                self.proj,
            ]

            def _unlock(x):
                if isinstance(x, Sequence):
                    for g in x:
                        _unlock(g)
                else:
                    if isinstance(x, torch.nn.Parameter):
                        x.requires_grad = True
                    else:
                        for p in x.parameters():
                            p.requires_grad = True

- add free fn that can freeze both `TextTransformer` and `CLIP`
@rsomani95
Copy link
Contributor Author

@rwightman that freezing code is a lot cleaner, I've updated mine to mimic the vision lock.

Also implemented the free function as I understood it. Let me know if this is in line with what you were describing above.

@dfan
Copy link

dfan commented Oct 25, 2023

The lock_text_transformer function is not correct because it doesn't add a case for handling a list of modules. It will raise a TypeError

@rsomani95
Copy link
Contributor Author

@dfan with which model do you encounter that error?

@Yangr116
Copy link

Yangr116 commented Jan 5, 2024

The token_embedding is nn.Embedding, so it will cause a type error.

@rohun-tripathi
Copy link

@rsomani95 With ViT-H-14-378-quickgelu and ViT-H-14-336 I see a type error.
Which models did you test?

I temporarily resolved it using the following version but am yet to run a complete training -

def lock_text_transformer(transformer: TextTransformer, unlocked_layers: int = 0, freeze_layer_norm: bool = True):

    groups = [
        [transformer.token_embedding, transformer.positional_embedding],
        *transformer.transformer.resblocks[:-1],
        [transformer.transformer.resblocks[ -1], transformer.ln_final],
        transformer.text_projection,
    ]

    def _freeze(modules, freeze_layer_norm: bool = True):
        for module in modules:
            # `CLIP.text_projection` and `CLIP.positional_embedding`
            if isinstance(module, nn.Parameter):
                print(f"Freezing module {module} which is a nn.Parameter")
                module.requires_grad = False

            # All other modules
            elif isinstance(module, nn.Module):
                print(f"Freezing module {module} which is a nn.Module")
                for n, p in module.named_parameters():
                    p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

            elif isinstance(module, list):
                for sub_module in module:
                    
                    if isinstance(sub_module, nn.Parameter):
                        print(f"Freezing sub_module {sub_module} which is a nn.Parameter")
                        sub_module.requires_grad = False
                    
                    elif isinstance(sub_module, nn.Module):
                        print(f"Freezing sub_module {sub_module} which is a nn.Module")
                        for n, p in sub_module.named_parameters():
                            p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

                    else:
                        raise TypeError(f"Encountered unexpected module type {type(sub_module)} for sub_module {sub_module}")
            else:
                raise TypeError(f"Encountered unexpected module type {type(module)} for module {module}")

Interpause added a commit to Interpause/open_clip that referenced this pull request May 23, 2024
@rsomani95
Copy link
Contributor Author

rsomani95 commented Jun 18, 2024

I think this should be good to go now.

@rohun-tripathi the updated code is a lot simpler, and should work functionally just as well. The LayerNorm freezing code I'd copied over naively was incorrect - I've fixed that now too. @dfan @Yangr116 curious if this works with whatever architectures you'll tried this with.

@rsomani95
Copy link
Contributor Author

@rwightman I couldn't think of a more robust way to extract the LayerNorm layers inside the resblocks:

elif isinstance(x, torch.nn.LayerNorm):
for p in x.parameters():
p.requires_grad = ln_status
else:
for n,p in x.named_parameters():
# This should grab LayerNorm inside `ResidualAttentionBlock` blocks
if n.startswith("ln_"):
p.requires_grad = ln_status
else:
p.requires_grad = True

The assumption here is that the only norm layer being used is LayerNorm. I think, but am not 100% sure, that all existing text models are using LayerNorm?

rsomani95 added a commit to Synopsis/open_clip that referenced this pull request Jun 18, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants