Network Arch Code refactoring to allow replacing the last FC and add a freeze core parameter #6552

Closed
AHarouni opened this issue May 24, 2023 · 12 comments

@AHarouni

Describe the solution you'd like
Once there is a good model available (one we trained ourselves, one from the model zoo, or one from another source), we would like to easily replace its last FC layer with one that has a different number of outputs, for example swapping the 104-output final layer of a total segmentator model. Today the user gets errors when loading the checkpoint and has to do the network surgery by hand.

The second request: the user would typically want to load the checkpoint, freeze the core parameters, and train only the newly added FC layer.

Describe alternatives you've considered
The code should be refactored so that the last FC layer can be given a different name; with that, the checkpoint would load without any errors.
As in Project-MONAI/MONAILabel#1298, I had to write my own SegResNet subclass (below) to change the name of the last layer:

import torch
from monai.networks.nets import SegResNet


class SegResNetNewFC(SegResNet):
    def __init__(self, spatial_dims: int = 3, init_filters: int = 8, out_channels: int = 2, **kwargs):
        # build the parent network without its final conv block
        super().__init__(spatial_dims=spatial_dims, init_filters=init_filters, use_conv_final=False, **kwargs)
        self.conv_final = None  # drop the original final layer
        self.conv_final_xx = self._make_final_conv(out_channels)  # re-create it under a new name

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        before_fc = super().forward(x)  # builds everything without the last FC, since we init with use_conv_final=False
        x = self.conv_final_xx(before_fc)
        return x
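
A quick sketch of why the rename helps when loading (the checkpoint path here is hypothetical): with strict=False the pretrained core loads cleanly, the old conv_final.* entries are reported as unexpected, and the new conv_final_xx.* parameters stay randomly initialised for fine-tuning.

net = SegResNetNewFC(out_channels=3)
state = torch.load("pretrained_segresnet.pt")  # hypothetical checkpoint from a stock SegResNet
missing, unexpected = net.load_state_dict(state, strict=False)
print(missing)     # conv_final_xx.* keys, newly added and absent from the checkpoint
print(unexpected)  # conv_final.* keys, present in the checkpoint but dropped here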

Additional context
We do have code to freeze the layers, but it needs to become a utility that calls into a function of the network to get the FC layer name. For this we should have an interface that network models implement so this utility function can be called.
I have it hacked together as:

import logging
import re

import torch
from monai.networks.utils import get_state_dict

logger = logging.getLogger(__name__)


# method on my MONAI Label train task (context is the MONAI Label training context):
def optimizer(self, context):
    freezeLayers(self._network, ["conv_final_xx"], True)
    context.network = self._network
    return torch.optim.AdamW(context.network.parameters(), lr=1e-4, weight_decay=1e-5)  # lr is probably too small


def freezeLayers(network, exclude_vars, exclude_include=True):
    """Freeze every parameter except those matching ``exclude_vars`` (or the inverse if ``exclude_include`` is False)."""
    src_dict = get_state_dict(network)
    to_skip = set()
    for exclude_var in exclude_vars:
        to_skip.update(s_key for s_key in src_dict if exclude_var and re.compile(exclude_var).search(s_key))
    logger.info(f"--------------------- layer freeze with {exclude_include=}")
    for name, para in network.named_parameters():
        # exclude_include=True: freeze everything *except* the matched layers
        freeze = (name not in to_skip) if exclude_include else (name in to_skip)
        if freeze:
            logger.info(f"------------ freezing {name} size {para.shape}")
            para.requires_grad = False
        else:
            logger.info(f"----training ------- {name} size {para.shape}")
@Nic-Ma (Contributor) commented May 25, 2023

Hi @wyli @ericspod,

Could you please share some comments on this question?

Thanks in advance.

@ericspod (Member)

Right now, for the networks which do have final layers and named layers/blocks, those layers can be substituted directly, and the optimiser can be given only the parameters for the new layers:

import monai
import torch

net = monai.networks.nets.SegResNet()
print(net(torch.rand(1, 1, 8, 8, 8)).shape)  # torch.Size([1, 2, 8, 8, 8])
net.conv_final = monai.networks.blocks.Convolution(3, 8, 5, kernel_size=1)  # 5 output channels instead of 2
print(net(torch.rand(1, 1, 8, 8, 8)).shape)  # torch.Size([1, 5, 8, 8, 8])

opt = torch.optim.Adam(net.conv_final.parameters())
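
As a complementary sketch (not part of the snippet above), the rest of the network can also be explicitly frozen so no gradients are computed for it at all:

for name, p in net.named_parameters():
    if not name.startswith("conv_final."):  # keep only the new final layer trainable
        p.requires_grad_(False)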

I'm not sure what more you're looking for here other than some utilities to help with doing this and renaming some internals of networks. We could add some helpers, but often the code for a forward method is going to need adapting anyway, so making a subclass would still make sense. Changing names in existing networks breaks compatibility with existing saved states, so we need to be very careful about what sort of changes we feel are justified for such a refactor.

@AHarouni (Author)

Hi @ericspod
Thanks for your response. I think you get what I am trying to do. I understand that I can do it by writing my own code (both you and I have shown this above). I am just asking for the responsibility for these code changes to be transferred back to the developer/researcher who creates the network, rather than to the consuming user who simply wants to fine-tune a well-established model.
My point is that this is a typical workflow, expected ever since AlexNet and ResNet came out. So maybe, to clarify my ask, I think fine-tuning:

  • should be provided to the user out of the box, with some parameter changes when calling the network, or maybe an additional function call;
  • should not require the user to go into the code and find the name of the last layer (conv_final).

Maybe we can create an interface that all networks should implement, where they are forced to add a function to the network that changes the last layer, e.g.:

def change_last_FC(self, new_output_classes: int):
    ...  # here is where the layer change would happen
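
Purely for illustration, a minimal sketch of what the body could look like for SegResNet (change_last_FC is the proposal above, not an existing MONAI method):

def change_last_FC(self, new_output_classes: int):
    # rebuild the final norm/act/conv block with the new channel count;
    # the fresh parameters are randomly initialised and are the ones to fine-tune
    self.conv_final = self._make_final_conv(new_output_classes)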

For the freezing option, again I don't want the user to know the details of the architecture or the names of the parameters to pass to the optimizer. This gets more complicated when the last "FC layer" is actually a set of layers, as in SegResNet (see below).
Instead, maybe we can have yet another function as part of the network that does this freezing by setting requires_grad to False.

So the steps to do fine-tuning should be:
1- create the network with its defaults
2- load the checkpoint
3- call net.change_last_FC(new_output_classes=3)
4- call net.freeze()
5- continue the rest of your code as normal
This should require only minimal changes to workflows such as the ones we use in MONAI Label (see the sketch below).
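
A hypothetical end-to-end sketch of those steps (change_last_FC and freeze are the proposed methods, not existing MONAI APIs, and the checkpoint path is made up):

import monai
import torch

net = monai.networks.nets.SegResNet()
net.load_state_dict(torch.load("pretrained_segresnet.pt"))  # hypothetical pretrained checkpoint
net.change_last_FC(new_output_classes=3)  # proposed: rebuild the final layer with 3 outputs
net.freeze()  # proposed: requires_grad=False everywhere except the new final layer
# ... continue the rest of the training code as normal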

Maybe to clarify this isolation need: if you look at the SegResNet code you will see that the last layer is more than just a conv; it is normalization, ReLU, and conv. The user should not have to worry about these details:

    def _make_final_conv(self, out_channels: int):
        return nn.Sequential(
            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters),
            self.act_mod,
            get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True),
        )

I agree with you that "for changing names in existing networks this breaks compatibility with existing saved states so we need to be very careful". I think that adding this interface and having the networks implement it should not break checkpoints, but I might be wrong. Also, exporting to TorchScript sometimes has issues with parent classes, which I have run into before, so that should be tested.

@ericspod (Member)

Hi @AHarouni, thanks for the outline. I think it's clear what you want to do, but I don't think the changes to the networks really justify this. Even if we can make changes that don't break existing saved states, these would be for existing networks only and would adhere to a pattern we couldn't expect others to adopt. We have always wanted components in MONAI to be loosely coupled, so we didn't impose requirements on network architecture. I don't really want to move away from this in exchange for a mechanism that simplifies a relatively small aspect of network training.

We have, in some networks like the Regressor, some aspect of what you want, and we could possibly change others, but this covers only one sort of fine-tuning that someone might want to do and isn't a general solution to that problem. It works for classifiers where you want to change what is being classified and how many classes by adapting a final layer, but other architectures wouldn't necessarily be designed in a way that permits changing just one aspect of the network so simply.

During training you may also want to have gradients through the network as normal and optimise the new layer as normal, but update the rest of the network at a much slower rate. A freeze function of some sort to adjust which parameters of a network are trained might make sense for this or for your training regime, but I don't want to add new method requirements to networks.
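
A minimal sketch of that alternative using optimizer parameter groups, so the new head trains at the normal rate while the backbone moves much more slowly (names follow the SegResNet example above; the learning rates are placeholders):

head_params = list(net.conv_final.parameters())
head_ids = {id(p) for p in head_params}
backbone_params = [p for p in net.parameters() if id(p) not in head_ids]
opt = torch.optim.Adam([
    {"params": head_params, "lr": 1e-3},
    {"params": backbone_params, "lr": 1e-5},
])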

As I said, I see where you're coming from and what the solution is for your fine-tuning strategy, but I don't see it as a general enough solution to be worth the consequences. I'm less worried about requiring users to delve into the workings of a network to figure out how to adapt components and then do fine-tuning; I imagine people will need to know about the inner workings quite often anyway, except in simple cases. If you wanted to propose a more concrete set of changes or a full PR, we'd be happy to discuss it there as well as here; at least some of the components of what you want to add seem doable. Thanks!

CC @Nic-Ma @wyli @atbenmurray for any comments?

@AHarouni (Author)

Hi @ericspod
I am glad you get what I am trying to do. I also see and understand your point. This has always been the challenge: giving researchers as much flexibility as possible while adding constraints to enable workflows.

For me to figure out the next steps I would like to know:

  1. Do you have any objections to having an interface that supports a fine-tuning workflow?
  2. Is it OK to have it as an optional interface which, when implemented, would support this fine-tuning workflow?
  3. For older architectures, can we create a new network per architecture that inherits this interface?

If that is OK, then the remaining issue is coming up with the interface. I can then propose one and show that it works with a fine-tuning workflow, using one network as an example. I can then open an issue for each network architecture to create a new network supporting that interface. It is up to the researchers to respond, or maybe NVIDIA's developers can contribute it.

The problem is that the solution I proposed above doesn't work for MONAI Label, as the weights are loaded through handlers. I have been thinking more and more about this. The core problem is that when loading the checkpoint, the last layer's dimensions don't match. What do you think about adding a function in the network class that would chop off the weights of the FC layer (or whichever layers the researcher thinks should be fine-tuned)? That way the network creation in the fine-tuning workflow and in normal training would remain the same, including loading from a checkpoint to continue training.
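
One generic way to sketch that "chop off the mismatched weights" step outside any particular network (ckpt_path is hypothetical; net is the already modified network):

import torch

state = torch.load(ckpt_path)
model_state = net.state_dict()
# keep only the entries whose name and shape still match the modified network
filtered = {k: v for k, v in state.items() if k in model_state and v.shape == model_state[k].shape}
net.load_state_dict(filtered, strict=False)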

@ericspod (Member)

Hi @AHarouni,
I'm good with you proposing a solution to your problem (e.g. as a draft PR); then we can all discuss and comment on it. Our design choices are important to maintain, but some other design of utilities that help with fine-tuning would be a good addition. For the checkpoint loading, I think we'd want to look at a different loader, or modifications to the existing one, to allow ignoring missing keys in a particular way, plus a separate handler to do the network surgery afterwards.
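
For reference, the existing monai.handlers.CheckpointLoader already exposes relaxed matching along these lines; a minimal sketch (the path is hypothetical and the arguments should be double-checked against the current handler):

from monai.handlers import CheckpointLoader

loader = CheckpointLoader(
    load_path="pretrained.pt",
    load_dict={"net": net},
    strict=False,        # tolerate missing/unexpected keys
    strict_shape=False,  # skip weights whose shapes no longer match
)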

@AHarouni (Author)

I just found this tutorial: https://github.com/Project-MONAI/tutorials/blob/main/self_supervised_pretraining/swinunetr_pretrained/swinunetr_finetune.ipynb
Specifically the section below. I hope it is clear that this code should belong in the network class itself:

if use_pretrained is True:
    print("Loading Weights from the Path {}".format(pretrained_path))
    ssl_dict = torch.load(pretrained_path)
    ssl_weights = ssl_dict["model"]

    # Generate a new state dict so it can be loaded into the MONAI SwinUNETR model
    monai_loadable_state_dict = OrderedDict()
    # snapshot the current weights so the safeguard comparison below is meaningful
    model_prior_dict = {k: v.clone() for k, v in model.state_dict().items()}
    model_update_dict = dict(model_prior_dict)

    del ssl_weights["encoder.mask_token"]
    del ssl_weights["encoder.norm.weight"]
    del ssl_weights["encoder.norm.bias"]
    del ssl_weights["out.conv.conv.weight"]
    del ssl_weights["out.conv.conv.bias"]

    for key, value in ssl_weights.items():
        if key[:8] == "encoder.":
            if key[8:19] == "patch_embed":
                new_key = "swinViT." + key[8:]
            else:
                new_key = "swinViT." + key[8:18] + key[20:]
            monai_loadable_state_dict[new_key] = value
        else:
            monai_loadable_state_dict[key] = value

    model_update_dict.update(monai_loadable_state_dict)
    model.load_state_dict(model_update_dict, strict=True)
    model_final_loaded_dict = model.state_dict()

    # Safeguard test to ensure that the weights got loaded successfully
    layer_counter = 0
    for k, _v in model_final_loaded_dict.items():
        if k in model_prior_dict:
            layer_counter = layer_counter + 1

            old_wts = model_prior_dict[k]
            new_wts = model_final_loaded_dict[k]

            old_wts = old_wts.to("cpu").numpy()
            new_wts = new_wts.to("cpu").numpy()
            diff = np.mean(np.abs(old_wts - new_wts))  # was np.abs(old_wts, new_wts), which writes into new_wts
            print("Layer {}, the update difference is: {}".format(k, diff))
            if diff == 0.0:
                print("Warning: No difference found for layer {}".format(k))
    print("Total updated layers {} / {}".format(layer_counter, len(model_prior_dict)))
    print("Pretrained Weights Successfully Loaded!")


@ericspod (Member) commented Jun 15, 2023

Hi @AHarouni, I see the usefulness of this code for this network, but it's implementing behaviour that can be generalised so that the same adaptation can be applied to other networks. This would allow us to filter a state dictionary for any network, using a function or some translation table to filter the members of the incoming data. Other networks would benefit from the same thing with only minor differences, so I think it makes sense as a separate utility function. This way we don't rely on a method of a network being present to do this; it would work with non-MONAI networks as well.

I don't like the idea of requiring certain implementation details of our networks; we've always said that we want to maintain architectural similarity with PyTorch and compatibility with existing PyTorch code as much as possible. The methods of a network should be concerned solely with the construction and operation of the network, so adding methods of this type mixes purposes, I feel. We would also not be able to use this functionality with non-MONAI networks, as I said, but decoupled utilities which make few assumptions about the networks they operate on are more flexible in that they can be used with such networks.

I would still suggest that we can define generalised utilities; for example, the cell above can be reimplemented as a general function:

import numpy as np
import torch


def load_adapted_state_dict(network, new_weights, filter_func):
    """Load `new_weights` into `network`, renaming or dropping entries via `filter_func`."""
    # snapshot the current weights so the safeguard comparison below is meaningful
    model_prior_dict = {k: v.clone() for k, v in network.state_dict().items()}
    model_update_dict = dict(model_prior_dict)

    for key, value in new_weights.items():
        new_pair = filter_func(key, value)
        if new_pair is not None:
            model_update_dict[new_pair[0]] = new_pair[1]

    network.load_state_dict(model_update_dict, strict=True)
    model_final_loaded_dict = network.state_dict()

    # Safeguard test to ensure that the weights got loaded successfully
    layer_counter = 0
    for k, _v in model_final_loaded_dict.items():
        if k in model_prior_dict:
            layer_counter = layer_counter + 1

            old_wts = model_prior_dict[k]
            new_wts = model_final_loaded_dict[k]

            old_wts = old_wts.to("cpu").numpy()
            new_wts = new_wts.to("cpu").numpy()
            diff = np.mean(np.abs(old_wts - new_wts))  # elementwise difference of old vs. new weights
            print("Layer {}, the update difference is: {}".format(k, diff))
            if diff == 0.0:
                print("Warning: No difference found for layer {}".format(k))
    print("Total updated layers {} / {}".format(layer_counter, len(model_prior_dict)))
    print("Pretrained Weights Successfully Loaded!")


def _filter(k, v):
    # drop the SSL-specific and output-layer weights entirely
    if k in [
        "encoder.mask_token",
        "encoder.norm.weight",
        "encoder.norm.bias",
        "out.conv.conv.weight",
        "out.conv.conv.bias",
    ]:
        return None

    # rename encoder.* keys to the corresponding swinViT.* keys
    if k[:8] == "encoder.":
        if k[8:19] == "patch_embed":
            new_key = "swinViT." + k[8:]
        else:
            new_key = "swinViT." + k[8:18] + k[20:]

        return new_key, v
    else:
        return k, v


load_adapted_state_dict(model, torch.load(pretrained_path)["model"], _filter)

@AHarouni (Author)

Hi @ericspod
Nice work generalizing the code. I agree with your points and perspective.

I think the main thing we disagree about is responsibilities. I believe that if I would like to use a network and fine-tune it, I should NOT need to know anything about the network's layers, which layers to copy weights into, and which to ignore. It should be the responsibility of the researcher who created the network.
If I understood you correctly (please correct me if I misunderstood), you think it is OK for the user to dig in and understand the network's layer names and which should be copied and which should be ignored.

So basically, who should write the _filter function? I think it should be the researcher's responsibility, not that of the user trying to fine-tune the model.

@ericspod (Member)

The general purpose code would be something that does inspect the members of a network, like what I proposed here. Researchers providing their own networks can also provide functions built on this general purpose code to tweak the members they know should be changed. In general, though, I think this will suffice for only a very small number of fine-tuning or refinement cases, that is, only the things the implementors anticipate someone wanting to do; there are probably many other things people will want to do which will require understanding the network's inner workings anyway.

Either way, the code shouldn't be part of the class definition, because it's counter to our architectural ideas and introduces close coupling between the network and fine-tuning. It's fine for the implementor of the network to provide their own fine-tuning functions; if we end up with a more involved architectural pattern, we can think about what classes or other components to add to MONAI to facilitate it.

@KumoLiu (Contributor) commented Aug 28, 2023

I think this one can be partially addressed by the enhanced load API in this PR.

@ericspod (Member)

I think this one can be partially addressed by the enhanced load API in this PR.

There is definitely overlap with the concepts here: with the load process you are modifying the loaded network definition. Here, however, it's more about modifying weights while leaving the network structure initially unchanged and then later, for example, adapting final layers to suit different numbers of classes. Within the context of bundles we would want to modify the network structure and then load the weight data with a filtering process like the one described above, so to do this cleanly we'd need a handler class as well, instead of the CheckpointLoader we typically use.
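
A rough sketch of what such a surgery handler could look like, for illustration only (the event choice and the SegResNet-specific _make_final_conv call are assumptions, not an agreed design):

from ignite.engine import Engine, Events


class ReplaceFinalLayer:
    """Swap the network's final conv for one with a new output channel count when training starts."""

    def __init__(self, net, out_channels: int):
        self.net = net
        self.out_channels = out_channels

    def attach(self, engine: Engine) -> None:
        engine.add_event_handler(Events.STARTED, self)

    def __call__(self, engine: Engine) -> None:
        self.net.conv_final = self.net._make_final_conv(self.out_channels)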

wyli assigned wyli and KumoLiu on Sep 5, 2023
KumoLiu mentioned this issue on Sep 12, 2023
wyli pushed a commit that referenced this issue Sep 12, 2023
Part of #6552.

### Description
Add `freeze_layers`.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
wyli pushed a commit that referenced this issue Sep 12, 2023
…6917)

Part of #6552.

### Description
After PR #6835, we have added `copy_model_args` in the `load` API which
can help us update the state_dict flexibly.

https://github.com/KumoLiu/MONAI/blob/93a149a611b66153cf804b31a7b36a939e2e593a/monai/bundle/scripts.py#L397

Given this [issue](#6552),
we need to be able to filter the model's weights flexibly.
In `copy_model_state` we already have a "mapping" arg; the filter will be
more flexible if we can support regular expressions in the mapping.
This PR mainly adds support for regular expressions in the "mapping" arg.

In the
[example](#6552 (comment))
in this [issue](#6552),
after this PR, we can do something like:
```
exclude_vars = "encoder.mask_token|encoder.norm.weight|encoder.norm.bias|out.conv.conv.weight|out.conv.conv.bias"
mapping={"encoder.layers(.*).0.0.": "swinViT.layers(.*).0."}
dst_dict, updated_keys, unchanged_keys = copy_model_state(
       model, ssl_weights, exclude_vars=exclude_vars, mapping=mapping
)
```

Additionally, based on Eric's comments
[here](#6552 (comment)),
I totally agree that we could add a handler to make the pipeline easier to
implement, but perhaps this task does not need to be set as a "BundleTodo" for
MONAI v1.3 and can instead be an enhancement for MONAI in the near future.
What do you think? @ericspod @wyli @Nic-Ma

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
KumoLiu closed this as completed Sep 18, 2023