[Feature Request]: Support Model Surgery #262
Comments
Hi @innat, there are many existing approaches that can be used for inserting, deleting, and replacing layers in such a fashion. For example, model slicing is already a supported feature. The original model had two parallel branches and this model slicing code cuts one branch out and discards it. This model slicing functionality can be extended to the use cases you have mentioned.
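To illustrate what "model slicing" means here, a minimal sketch with the functional API: a toy two-branch model is built, then a sliced model is created that keeps one branch and discards the other. The names (`branch_a`, `branch_b`, `merge`) are illustrative, not from the linked code.

```python
import tensorflow as tf
from tensorflow import keras

# A toy model with two parallel branches merged at the end.
inputs = keras.Input(shape=(8,))
branch_a = keras.layers.Dense(4, name="branch_a")(inputs)
branch_b = keras.layers.Dense(4, name="branch_b")(inputs)
merged = keras.layers.Concatenate(name="merge")([branch_a, branch_b])
full_model = keras.Model(inputs, merged)

# "Slice" the model: reuse the original input and the symbolic tensor
# produced by branch_a; branch_b and the merge layer are cut out entirely.
sliced_model = keras.Model(full_model.input,
                           full_model.get_layer("branch_a").output)
```

Because the sliced model is rebuilt from the original symbolic tensors, it shares weights with `full_model`; the discarded branch simply no longer appears in its layer graph.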
Please give me a moment to respond, what's the hurry! Let me check and get back to you.
@innat Sorry, accidentally hit close on it 😅 my bad! Let me know if this helps :)
Hi @nkovela1 Thanks for the response. Note that there are three types of modeling APIs in total (sequential, functional, subclassed), so a generic high-level method in the API could serve a great role here. Let me know what you think. Related: (keras-team/keras#16355)
Mentioning @ZFTurbo @BenWhetton @leondgarse to give some input.
@innat +1. This will be handy for complex models, where there are multiple layers and branches that need to be modified or removed. In such cases, manual surgery can be time-consuming and error-prone, and a high-level method within the Keras API would simplify the process and make it more accessible to users.
A rough sketch of such a `clone_model` utility:

```python
import tensorflow as tf
from tensorflow import keras

def clone_model(model, clone_function=None):
    # Run through the model's forward pipeline
    inner_tensors = {ii.name: ii for ii in model.inputs}
    for layer in model.layers[1:]:
        if isinstance(layer.input, list):
            inputs = [inner_tensors[ii.name] for ii in layer.input]
        else:
            inputs = inner_tensors[layer.input.name]
        # Key call: clone_function handles replacing / deleting / inserting layers
        out = layer(inputs) if clone_function is None else clone_function(layer, inputs)
        inner_tensors[layer.output.name] = out
    dd = keras.models.Model(model.inputs, [inner_tensors[ii.name] for ii in model.outputs])
    # Have to create a new model, or the delete operation is not actually applied.
    # May need some operations on `inbound_nodes` and `outbound_nodes`. [???]
    # Also needs further investigation for training the model.
    dd = keras.models.model_from_json(dd.to_json())
    # Reload weights, or alternatively:
    # model.save_weights('aa.h5')
    # dd.load_weights('aa.h5', by_name=True, skip_mismatch=True)
    origin_layers = [ii.name for ii in model.layers]
    for ii in dd.layers:
        if ii.name in origin_layers:
            ii.set_weights(model.get_layer(ii.name).get_weights())
    return dd
```

Basic tests:

```python
mm = keras.applications.ResNet50(weights=None)
print(f"{mm.count_params() = }")
# mm.count_params() = 25636712

""" Delete all batchnorm layers """
delete_bn = lambda layer, inputs: inputs if isinstance(layer, keras.layers.BatchNormalization) else layer(inputs)
dd = clone_model(mm, clone_function=delete_bn)
print(f"{dd.count_params() = }")
# dd.count_params() = 25530472

""" Insert `relu + DepthwiseConv2D + BatchNormalization` after each BatchNormalization layer """
def insert_layers_function(layer, inputs):
    if isinstance(layer, keras.layers.BatchNormalization):
        pre = layer(inputs)
        nn = tf.nn.relu(pre)
        nn = keras.layers.DepthwiseConv2D(kernel_size=1, name=layer.name + "_depth_conv")(nn)
        nn = keras.layers.BatchNormalization(name=layer.name + "_depth_conv_bn")(nn)
        return keras.layers.Add()([nn, pre])  # `nn + pre` not working well [???]
    return layer(inputs)

dd = clone_model(mm, clone_function=insert_layers_function)
print(f"{dd.count_params() = }")
# dd.count_params() = 25796072

""" Delete all batchnorm layers on the new model;
this one will throw an error if the new model is not created via `model_from_json`
"""
ee = clone_model(dd, clone_function=delete_bn)
print(f"{ee.count_params() = }")
# ee.count_params() = 25583592
```
@leondgarse This is a good start, one main thing to consider is having packed sequential models, something like
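One way the nested ("packed") sub-model case raised above might be handled is to flatten nested sub-models into a single layer sequence before doing any surgery. A minimal sketch; `flatten_layers` is a hypothetical helper, not an existing Keras utility.

```python
import tensorflow as tf
from tensorflow import keras

def flatten_layers(model):
    """Recursively yield the leaf layers of a model, descending into
    any nested Sequential / functional sub-models."""
    for layer in model.layers:
        if isinstance(layer, keras.Model):  # nested sub-model
            yield from flatten_layers(layer)
        else:
            yield layer

# A Sequential packed inside another Sequential.
inner = keras.Sequential([keras.layers.Dense(4), keras.layers.Dense(2)], name="inner")
outer = keras.Sequential([keras.Input((8,)), inner, keras.layers.Dense(1)], name="outer")
```

With a flattened view like this, a surgery utility could treat the three `Dense` layers uniformly instead of stopping at the opaque `inner` sub-model.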
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
😑
@innat I may start working on this if you haven't. This would be a handy feature. Edit: This is in progress but I can't give an ETA now. Edit 2 (05.04.23): I'll be working on this actively, it is half complete now. Edit 3: This needs more work.
TensorFlow version (you are using): 2.9.2
Are you willing to contribute it (Yes/No):
Describe the feature and the current behavior/state.
This request originated from this old post. In short, given a model defined as in the following `keras_simple_model`, I'd like to be able to update the model, for example:

- replace a `Conv2D` layer with the same layer but without bias;
- insert a `BatchNormalization` layer before the first `Activation`.

Will this change the current api? How?
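The first of the two requested operations can already be approximated with the existing `keras.models.clone_model` and its `clone_function` argument, which swaps layers one-for-one (it cannot insert or delete layers, which is why a fuller surgery API is being requested). A hedged sketch; the model and `strip_bias` helper are illustrative, not from the issue.

```python
import tensorflow as tf
from tensorflow import keras

def strip_bias(layer):
    # Replace every Conv2D with an identically configured one without bias.
    if isinstance(layer, keras.layers.Conv2D):
        config = layer.get_config()
        config["use_bias"] = False
        return keras.layers.Conv2D.from_config(config)
    return layer.__class__.from_config(layer.get_config())

model = keras.Sequential([
    keras.Input((32, 32, 3)),
    keras.layers.Conv2D(8, 3, use_bias=True),
    keras.layers.Activation("relu"),
])
clone = keras.models.clone_model(model, clone_function=strip_bias)
# The clone lacks exactly the 8 bias parameters of the Conv2D layer.
```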
I think it is more suited to `tf.keras.utils`.

Who will benefit from this feature?
Certain layers are deactivated during the `model.evaluate` and `model.predict` phases. But for test-time augmentation, I'd like them to be active, which wasn't easily possible. So, if a utility like model surgery is supported, we could use it inside the callback API to update the target layer to behave differently, e.g. `model.predict(tensor, callback=[update])`.
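The test-time-augmentation use case can be sketched with today's tools by swapping each `Dropout` for a subclass that always runs in training mode. `AlwaysOnDropout` and `activate_dropout` are hypothetical names for this sketch, not part of the Keras API.

```python
import tensorflow as tf
from tensorflow import keras

class AlwaysOnDropout(keras.layers.Dropout):
    """Dropout that stays active even at inference time."""
    def call(self, inputs, training=None):
        return super().call(inputs, training=True)

def activate_dropout(layer):
    if isinstance(layer, keras.layers.Dropout):
        return AlwaysOnDropout.from_config(layer.get_config())
    return layer.__class__.from_config(layer.get_config())

model = keras.Sequential([
    keras.Input((4,)),
    keras.layers.Dense(64),
    keras.layers.Dropout(0.5),
])
tta_model = keras.models.clone_model(model, clone_function=activate_dropout)
tta_model.set_weights(model.get_weights())

# Dropout stays on even with training=False, so repeated predictions differ.
x = tf.ones((1, 4))
p1 = tta_model(x, training=False).numpy()
p2 = tta_model(x, training=False).numpy()
```

A surgery utility in `tf.keras.utils` would let this kind of swap happen on an already-built model (e.g. from a callback) without rebuilding and re-copying weights by hand.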
With details,

Contributing