
[Feature Request]: Support Model Surgery #262

Closed
innat opened this issue Feb 15, 2023 · 11 comments
innat commented Feb 15, 2023

TensorFlow version (you are using): 2.9.2
Are you willing to contribute it (Yes/No):

Describe the feature and the current behavior/state.

This request originated from this old post. In short, given a model defined as follows, I would like to be able to update it in place:

  • delete layers
  • insert layers
  • replace layers

For example, in the following keras_simple_model:

  1. replace a Conv2D layer with the same layer but without bias, or
  2. add a BatchNormalization layer before the first Activation:
from tensorflow.keras.layers import (Input, Conv2D, Activation, MaxPooling2D,
                                     GlobalAveragePooling2D, Dense)
from tensorflow.keras.models import Model

def keras_simple_model():
    inputs1 = Input((28, 28, 1))
    x = Conv2D(4, (3, 3), activation=None, padding='same', name='conv1')(inputs1)
    x = Activation('relu')(x)
    x = Conv2D(4, (3, 3), activation=None, padding='same', name='conv2')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(x)

    x = Conv2D(8, (3, 3), activation=None, padding='same', name='conv3')(x)
    x = Activation('relu')(x)
    x = Conv2D(8, (3, 3), activation=None, padding='same', name='conv4')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(x)

    x = GlobalAveragePooling2D()(x)
    x = Dense(10, activation=None)(x)
    x = Activation('softmax')(x)

    model = Model(inputs=inputs1, outputs=x)
    return model
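Today, the first change above (dropping the bias from conv1) requires manual surgery: rebuild the graph with a reconfigured layer and copy the kernel across. A minimal sketch of that workaround, using a trimmed-down stand-in for keras_simple_model:

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, Activation

# A trimmed-down model standing in for keras_simple_model.
inputs = Input((28, 28, 1))
x = Conv2D(4, (3, 3), padding='same', name='conv1')(inputs)
x = Activation('relu')(x)
model = keras.Model(inputs, x)

# Manual surgery: rebuild 'conv1' from its config with use_bias=False.
old = model.get_layer('conv1')
cfg = old.get_config()
cfg['use_bias'] = False
new_conv = Conv2D.from_config(cfg)

inputs2 = Input((28, 28, 1))
y = Activation('relu')(new_conv(inputs2))
model2 = keras.Model(inputs2, y)

# Copy the kernel only; the bias is discarded.
new_conv.set_weights([old.get_weights()[0]])
```

Even this trivial case needs a full rebuild; for branched graphs the bookkeeping grows quickly, which is what a built-in surgery utility would avoid.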

Will this change the current api? How?

I think it is more suited to tf.keras.utils.

Who will benefit from this feature?

  1. I have faced an issue where I used augmentation layers inside the model. These layers are inactive during the model.evaluate and model.predict phases, but for test-time augmentation I would like them to be active, which is not easily possible. If a model-surgery utility were supported, we could use it inside the callback API to make the target layer behave differently, e.g. model.predict(tensor, callbacks=[update]). In detail:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import RandomFlip, RandomZoom, RandomRotation

class TrainingAugmentationLayers(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.random_flip = RandomFlip("horizontal")
        self.random_zoom = RandomZoom(0.1, fill_mode='reflect')
        self.random_rotate = RandomRotation(0.1, fill_mode='reflect')

    def call(self, inputs):
        x = self.random_flip(inputs)
        x = self.random_zoom(x)
        x = self.random_rotate(x)
        return x

class PredictAugmentationLayers(TrainingAugmentationLayers):
    def call(self, inputs):
        # Force the augmentations on, even at inference time.
        x = self.random_flip(inputs, training=True)
        x = self.random_zoom(x, training=True)
        x = self.random_rotate(x, training=True)
        return x

img = tf.random.normal([20, 512, 512, 3], 0, 1, tf.float32)
y_true = img

inputs = tf.keras.Input(shape=(512, 512, 3))
x = TrainingAugmentationLayers()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=x)

class UpdateCallback(keras.callbacks.Callback):
    def __init__(self, target_layer_id, new_layer):
        super().__init__()
        self.target_layer_id = target_layer_id
        self.new_layer = new_layer

    def on_predict_begin(self, logs=None):
        # `keras.utils.update_model` is the proposed (not yet existing) utility.
        keras.utils.update_model(
            self.model,
            self.target_layer_id,
            self.new_layer,
        )

model.predict(
    img,
    callbacks=[
        UpdateCallback(target_layer_id=0, new_layer=PredictAugmentationLayers)
    ],
)
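To be clear, keras.utils.update_model above is the proposed utility and does not exist in Keras today. Purely as a hypothetical sketch (all names assumed), a version restricted to single-input, single-output chains could look like:

```python
from tensorflow import keras

def update_model(model, target_layer_id, new_layer_factory):
    # Hypothetical sketch of the proposed utility: rebuild a simple
    # single-input, single-output chain, replacing the layer at
    # `target_layer_id` with a freshly constructed layer. Untouched
    # layers are reused, so their weights are shared with the original.
    x = inputs = keras.Input(shape=model.input_shape[1:])
    for i, layer in enumerate(model.layers[1:]):  # skip the InputLayer
        x = new_layer_factory()(x) if i == target_layer_id else layer(x)
    return keras.Model(inputs, x)

# Usage sketch: swap the first Dense for a fresh replacement.
inp = keras.Input((8,))
h = keras.layers.Dense(4, name='d1')(inp)
out = keras.layers.Dense(2, name='d2')(h)
base = keras.Model(inp, out)

patched = update_model(base, 0, lambda: keras.layers.Dense(4, name='d1_new'))
```

The replacement must of course produce a shape the downstream layers can still consume; a real utility would need to validate that.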
  2. There is open-source support for this, and it remains quite popular. Please check:
  3. Popular questions/answers on SO: https://stackoverflow.com/q/49492255/9215780

Contributing

  • Do you want to contribute a PR? (yes/no): Yes (but I need some design-level suggestions covering all modeling APIs).
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution(if contributing):
@nkovela1
Contributor

nkovela1 commented Feb 16, 2023

Hi @innat, there are many existing approaches that can be used for inserting, deleting, and replacing layers in such a fashion. For example, model slicing is already a supported feature.

[images: model graph before and after slicing]

The original model had two parallel branches and this model slicing code cuts one branch out and discards it. This model slicing functionality can be extended to the use cases you have mentioned.


# get one of the branches
branch1 = model.get_layer("branch1")

# get the tail
tail = model.get_layer("tail")

# splice branch1 and tail together (discarding branch2
# as well as the Add op)

x = branch1.input

y = branch1(x)
y = tail(y)
model3 = tf.keras.Model(x, y)
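The snippet above assumes a model containing nested sub-models named branch1, branch2, and tail; a self-contained reconstruction (shapes and layer sizes assumed) that makes the slicing runnable:

```python
import tensorflow as tf

def make_branch(units, name, in_dim):
    # Small functional sub-model used as a named branch.
    inp = tf.keras.Input((in_dim,))
    return tf.keras.Model(inp, tf.keras.layers.Dense(units)(inp), name=name)

branch1 = make_branch(8, "branch1", 16)
branch2 = make_branch(8, "branch2", 16)
tail = make_branch(2, "tail", 8)

inp = tf.keras.Input((16,))
out = tail(tf.keras.layers.Add()([branch1(inp), branch2(inp)]))
model = tf.keras.Model(inp, out)

# Slice: keep branch1 + tail, discard branch2 and the Add op.
# (A fresh Input avoids the ambiguity of branch1.input when the
# sub-model has been called more than once.)
x = tf.keras.Input((16,))
y = model.get_layer("tail")(model.get_layer("branch1")(x))
model3 = tf.keras.Model(x, y)
```

The sliced model reuses the original sub-model objects, so their weights stay shared with `model`.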

@innat
Author

innat commented Feb 16, 2023

Please give me some time to respond. What's the hurry!
It took time to write all of that down. 😅

Let me check and get back to you.

@nkovela1 nkovela1 reopened this Feb 16, 2023
@nkovela1
Contributor

nkovela1 commented Feb 16, 2023

@innat Sorry, accidentally hit close on it 😅 my bad! Let me know if this helps :)

@innat
Author

innat commented Feb 17, 2023

Hi @nkovela1, thanks for the response.
The demonstration you made is doable because of its simplicity. But things can get messy if you try to remove (or replace) conv1.2 and conv2.2 from the two branches and continue. Please check these questions/answers/linked Q&A, or this issue.

Not to mention, there are three types of modeling API in total (Sequential, Functional, subclassed), so a generic high-level method in the API could serve a great role here. Let me know what you think.

Related (keras-team/keras#16355)

@innat
Author

innat commented Feb 17, 2023

Mentioning @ZFTurbo @BenWhetton @leondgarse to give some inputs.

@Frightera
Contributor

@innat +1. This will be handy for complex models, where there are multiple layers and branches that need to be modified or removed. In such cases, manual surgery can be time-consuming and error-prone, and a high-level method within the Keras API would simplify the process and make it more accessible to users.

@leondgarse
Contributor

leondgarse commented Feb 22, 2023

  • My first thought is to use keras.models.clone_model for replacing a layer; substituting keras.layers.Activation(activation="linear") for deleting; and a module like keras.models.Sequential([source_layer, target_layer]) for inserting. It works in most cases, but the generated model is not very neat...
  • Here is another basic implementation on vanilla ResNet50, inspired by this Stack Overflow answer how-to-replace-or-insert-intermediate-layer-in-keras-model. The key difference from keras.models.clone_model is that the clone_function takes 2 arguments, layer and inputs. We can therefore do more things within it, as long as it returns a tensor. There are still many errors in this...
import tensorflow as tf
from tensorflow import keras

def clone_model(model, clone_function=None):
    # Run through the model forward pipeline
    inner_tensors = {ii.name: ii for ii in model.inputs}
    for layer in model.layers[1:]:
        # print(layer.name)
        if isinstance(layer.input, list):
            inputs = [inner_tensors[ii.name] for ii in layer.input]
        else:
            inputs = inner_tensors[layer.input.name]

        # Key function calling clone_function for replacing / deleting / inserting layer
        out = layer(inputs) if clone_function is None else clone_function(layer, inputs)
        inner_tensors[layer.output.name] = out
    dd = keras.models.Model(model.inputs, [inner_tensors[ii.name] for ii in model.outputs])
    # Have to create a new model, or the delete operation is not actually applied.
    # May need some operations on `inbound_nodes` and `outbound_nodes`. [???]
    # Also need other investigation for training model.
    dd = keras.models.model_from_json(dd.to_json())

    # Reload weights, or:
    # model.save_weights('aa.h5')
    # dd.load_weights('aa.h5', by_name=True, skip_mismatch=True)
    origin_layers = [ii.name for ii in model.layers]
    for ii in dd.layers:
        if ii.name in origin_layers:
            # print(ii.name)
            ii.set_weights(model.get_layer(ii.name).get_weights())
    return dd

Basic tests:

mm = keras.applications.ResNet50(weights=None)
print(f"{mm.count_params() = }")
# mm.count_params() = 25636712

""" Delete all batchnorm layers """
delete_bn = lambda layer, inputs: inputs if isinstance(layer, keras.layers.BatchNormalization) else layer(inputs)
dd = clone_model(mm, clone_function=delete_bn)
print(f"{dd.count_params() = }")
# dd.count_params() = 25530472

""" Insert `relu + DepthwiseConv2D + BatchNormalization` after each BatchNormalization layer """
def insert_layers_function(layer, inputs):
    if isinstance(layer, keras.layers.BatchNormalization):
        pre = layer(inputs)
        nn = tf.nn.relu(pre)
        nn = keras.layers.DepthwiseConv2D(kernel_size=1, name=layer.name + "_depth_conv")(nn)
        nn = keras.layers.BatchNormalization(name=layer.name + "_depth_conv_bn")(nn)
        return keras.layers.Add()([nn, pre])  # nn + pre not working well [???]
    return layer(inputs)
dd = clone_model(mm, clone_function=insert_layers_function)
print(f"{dd.count_params() = }")
# dd.count_params() = 25796072

""" Delete all batchnorm layers on new model,
this one will throw error if not creating new model using `model_from_json`
"""
ee = clone_model(dd, clone_function=delete_bn)
print(f"{ee.count_params() = }")
# ee.count_params() = 25583592

@Frightera
Contributor

@leondgarse This is a good start; one main thing to consider is handling nested (packed) Sequential models, something like expand_nested=True.
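One way to approach the nested-model concern: walk the layer tree and expand sub-models first, similar in spirit to model.summary(expand_nested=True). A small sketch (flatten_layers is a hypothetical helper, not a Keras API):

```python
from tensorflow import keras

def flatten_layers(model):
    # Yield leaf layers depth-first, expanding any nested
    # Sequential/Functional sub-models along the way.
    for layer in model.layers:
        if isinstance(layer, keras.Model):
            yield from flatten_layers(layer)
        else:
            yield layer

inner = keras.Sequential([keras.layers.Dense(4, name="inner_dense")], name="inner")
outer = keras.Sequential([
    keras.Input(shape=(8,)),
    inner,
    keras.layers.Dense(2, name="outer_dense"),
])
names = [l.name for l in flatten_layers(outer)]
# names == ['inner_dense', 'outer_dense']
```

A surgery utility built on such a traversal could then address layers inside nested sub-models, not just the top level.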

@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@innat
Author

innat commented Mar 2, 2023

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

😑

@Frightera
Contributor

Frightera commented Mar 13, 2023

@innat I may start working on this if you haven't. This would be a handy feature.

Edit: This is in progress, but I cannot give an ETA now.

Edit2 (05.04.23): I'll be working on this actively, this is half complete now.

Edit3: This needs more work.
