
Network Arch Code refactoring to allow replacing the last FC and add a freeze core parameter  #6552

Closed
@AHarouni

Description


Describe the solution you'd like
Once a good model is available (one we trained, one from the model zoo, or one from another source), we would like to easily replace its last FC layer with one that has a different number of outputs, for example swapping out the 104-output layer of a total segmentator model. Currently the user gets shape-mismatch errors when loading the checkpoint and has to do some network surgery by hand.
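For concreteness, a minimal sketch of the failure, assuming the pretrained checkpoint is a plain state_dict saved from a SegResNet with 104 output channels (the file name here is just a placeholder):

import torch
from monai.networks.nets import SegResNet

# new network with a different number of output channels than the checkpoint
net = SegResNet(spatial_dims=3, init_filters=8, out_channels=5)
state = torch.load("pretrained_104_outputs.pt", map_location="cpu")
net.load_state_dict(state)
# raises RuntimeError: size mismatch for the conv_final weights (104 vs 5 output channels),
# so the user has to filter the offending keys out of the state_dict by hand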

The second request: after loading the checkpoint, the user usually wants to freeze the core parameters and train only the newly added FC.

Describe alternatives you've considered
The code should be refactored so the last FC layer can be given a different name; with that, the checkpoint would load without any errors.
As in Project-MONAI/MONAILabel#1298, I had to write my own SegResNet subclass (below) to change the name of the last layer:

import torch
from monai.networks.nets import SegResNet


class SegResNetNewFC(SegResNet):
    def __init__(self, spatial_dims: int = 3, init_filters: int = 8,
                 out_channels: int = 2, **kwargs):
        # build the backbone with use_conv_final=False so forward skips the original final conv
        super().__init__(spatial_dims=spatial_dims, init_filters=init_filters,
                         use_conv_final=False, **kwargs)
        self.conv_final = None  # delete the original final layer
        self.conv_final_xx = self._make_final_conv(out_channels)  # new name, so no checkpoint key clash

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        before_fc = super().forward(x)  # runs everything except the final conv (use_conv_final=False)
        return self.conv_final_xx(before_fc)
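With the final layer renamed, the same checkpoint loads cleanly. A usage sketch, again assuming a plain state_dict checkpoint (placeholder file name); load_state_dict(strict=False) just reports the old conv_final keys as unexpected and leaves the new head randomly initialized:

net = SegResNetNewFC(spatial_dims=3, init_filters=8, out_channels=5)
state = torch.load("pretrained_104_outputs.pt", map_location="cpu")
missing, unexpected = net.load_state_dict(state, strict=False)
# missing: conv_final_xx.* (the new head), unexpected: conv_final.* (the old 104-output head)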

Additional context
We do have code to freeze the layers, but it needs to become a utility that asks the network for the name of its FC layer. For that, we should have an interface that network models implement so this utility can be called generically (a sketch of such an interface follows the code below). I currently have it hacked together as:

    # hook in my MONAILabel train task: freeze everything except the new FC before building the optimizer
    def optimizer(self, context: Context):
        freezeLayers(self._network, ["conv_final_xx"], True)
        context.network = self._network
        return torch.optim.AdamW(context.network.parameters(), lr=1e-4, weight_decay=1e-5)  # lr is probably too small


# shared helper
import logging
import re

from monai.networks.utils import get_state_dict  # or simply use network.state_dict()

logger = logging.getLogger(__name__)


def freezeLayers(network, exclude_vars, exclude_include=True):
    """exclude_vars are regex patterns; with exclude_include=True the matched parameters stay
    trainable and everything else is frozen, otherwise only the matches are frozen."""
    src_dict = get_state_dict(network)
    to_skip = set()  # parameter names matched by the patterns
    for exclude_var in exclude_vars:
        if exclude_var:
            to_skip.update(s_key for s_key in src_dict if re.compile(exclude_var).search(s_key))
    logger.info(f"--------------------- layer freeze with {exclude_include=}")
    for name, para in network.named_parameters():
        # exclude_include=True: freeze everything NOT matched; False: freeze only the matches
        freeze = (name not in to_skip) if exclude_include else (name in to_skip)
        if freeze:
            logger.info(f"------------ freezing {name} size {para.shape}")
            para.requires_grad = False
        else:
            logger.info(f"----training ------- {name} size {para.shape}")
