Description
Describe the solution you'd like
Once a good model is available (one we trained ourselves, one from the model zoo, or one from another source), we would like to easily replace its last FC layer with one that has a different number of outputs, e.g. swapping out the 104-output layer of a total segmentor model. Today the user gets errors when loading the checkpoint and has to do manual network surgery.
The second request: after loading the checkpoint, the user typically wants to freeze the core parameters and train only the newly added FC layer.
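For reference, the manual surgery currently looks roughly like the sketch below (a minimal example assuming the checkpoint is a raw SegResNet state dict whose final-layer keys start with `conv_final`; the file name and output count are made up):

```python
import torch
from monai.networks.nets import SegResNet

# new head: 10 outputs instead of the pretrained 104
model = SegResNet(spatial_dims=3, init_filters=8, out_channels=10)

ckpt = torch.load("segresnet_104.pt", map_location="cpu")  # hypothetical path
# drop the final-conv weights whose shapes no longer match
filtered = {k: v for k, v in ckpt.items() if not k.startswith("conv_final")}
model.load_state_dict(filtered, strict=False)  # tolerate the missing final layer

# freeze the core, train only the fresh final conv
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("conv_final")
```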
Describe alternatives you've considered
The code should be refactored to allow giving the last FC layer a different name; with that, the checkpoint would load without any errors.
As in Project-MONAI/MONAILabel#1298, I had to write my own SegResNet subclass (below) to change the name of the last layer:
```python
import torch
from monai.networks.nets import SegResNet


class SegResNetNewFC(SegResNet):
    def __init__(
        self,
        spatial_dims: int = 3,
        init_filters: int = 8,
        out_channels: int = 2,
        **kwargs,
    ):
        # use_conv_final=False builds the whole backbone without the final conv
        super().__init__(
            spatial_dims=spatial_dims, init_filters=init_filters, use_conv_final=False, **kwargs
        )
        self.conv_final = None  # delete the original final layer
        # re-create it under a new name so the checkpoint keys no longer collide
        self.conv_final_xx = self._make_final_conv(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        before_fc = super().forward(x)  # everything except the final conv
        return self.conv_final_xx(before_fc)
```
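With the renamed layer, loading the old checkpoint needs no surgery. A minimal usage sketch, assuming MONAI's `copy_model_state` utility (it copies only keys whose names and shapes match, so the checkpoint's `conv_final.*` weights are skipped and `conv_final_xx` keeps its fresh initialization; the checkpoint path is made up):

```python
import torch
from monai.networks.utils import copy_model_state

model = SegResNetNewFC(spatial_dims=3, init_filters=8, out_channels=10)
ckpt = torch.load("segresnet_104.pt", map_location="cpu")  # hypothetical path
copy_model_state(model, ckpt)  # loads the backbone, ignores unmatched keys
```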
Additional context
We do have code to freeze the layers, but it needs to become a utility that calls into a function of the network to get the FC layer name. For that, network models should implement a common interface so this utility can be called generically; a possible shape is sketched below.
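A minimal sketch of what such an interface could look like; the protocol name `HasFinalLayers` and the method `final_layer_names` are hypothetical, not an existing MONAI API:

```python
from typing import Protocol, Sequence

import torch.nn as nn


class HasFinalLayers(Protocol):  # hypothetical interface, not in MONAI today
    def final_layer_names(self) -> Sequence[str]:
        """Return parameter-name prefixes of the replaceable final layer(s)."""
        ...


def freeze_backbone(network: nn.Module) -> None:
    """Freeze every parameter except those of the declared final layers."""
    keep = tuple(network.final_layer_names())  # network implements HasFinalLayers
    for name, param in network.named_parameters():
        param.requires_grad = name.startswith(keep)
```

`SegResNetNewFC` above would then just return `["conv_final_xx"]`.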
For now I have it hacked together as follows:
```python
import logging
import re

import torch
from monai.networks.utils import get_state_dict
from monailabel.tasks.train.basic_train import Context

logger = logging.getLogger(__name__)


# optimizer hook of my MONAI Label train task (BasicTrainTask subclass)
def optimizer(self, context: Context):
    # keep only the renamed final layer trainable
    freezeLayers(self._network, ["conv_final_xx"], True)
    context.network = self._network
    return torch.optim.AdamW(context.network.parameters(), lr=1e-4, weight_decay=1e-5)  # too small


def freezeLayers(network, exclude_vars, exclude_include=True):
    """Freeze parameters by name.

    `exclude_vars` holds regex patterns matched against parameter names.
    With exclude_include=True the matched parameters stay trainable and all
    others are frozen; with exclude_include=False the matched ones are frozen
    instead.
    """
    src_dict = get_state_dict(network)
    to_skip = set()
    for exclude_var in exclude_vars:
        to_skip.update(key for key in src_dict if exclude_var and re.compile(exclude_var).search(key))
    logger.info(f"--------------------- layer freeze with {exclude_include=}")
    for name, para in network.named_parameters():
        freeze = (name not in to_skip) if exclude_include else (name in to_skip)
        if freeze:
            logger.info(f"------------ freezing {name} size {para.shape}")
            para.requires_grad = False
        else:
            logger.info(f"----training ------- {name} size {para.shape}")
```