Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
philipperemy committed Aug 22, 2020
1 parent 37c7132 commit e578241
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
2 changes: 0 additions & 2 deletions examples/model_in_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from tensorflow.keras.utils import plot_model

import keract

# gradients requires no eager execution.
import utils

tf.compat.v1.disable_eager_execution()
Expand Down
34 changes: 26 additions & 8 deletions keract/keract.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,22 @@ def n_(node, output_format_, nested=False):

def _evaluate(model: Model, nodes_to_evaluate, x, y=None, auto_compile=False):
if not model._is_compiled:
if model.name in ['vgg16', 'vgg19', 'inception_v3', 'inception_resnet_v2', 'mobilenet_v2', 'mobilenetv2']:
# tensorflow.python.keras.applications.*
applications_model_names = [
'densenet',
'efficientnet',
'inception_resnet_v2',
'inception_v3',
'mobilenet',
'mobilenet_v2',
'nasnet',
'resnet',
'resnet_v2',
'vgg16',
'vgg19',
'xception'
]
if model.name in applications_model_names:
print('Transfer learning detected. Model will be compiled with ("categorical_crossentropy", "adam").')
print('If you want to change the default behaviour, then do in python:')
print('model.name = ""')
Expand All @@ -75,6 +90,11 @@ def eval_fn(k_inputs):
if y is None: # tf 2.3.0 upgrade compatibility.
return K.function(k_inputs, nodes_to_evaluate)(x)
return K.function(k_inputs, nodes_to_evaluate)((x, y)) # although works.
except ValueError as e:
print('Run it without eager mode. Paste those commands at the beginning of your script:')
print('> import tensorflow as tf')
print('> tf.compat.v1.disable_eager_execution()')
raise e

try:
return eval_fn(model._feed_inputs + model._feed_targets + model._feed_sample_weights)
Expand Down Expand Up @@ -112,7 +132,7 @@ def get_gradients_of_activations(model, x, y, layer_names=None, output_format='s
- 'full': output key will match the full name of the output layer name. In the example above, it will
return {'d1/BiasAdd:0': ...}.
- 'numbered': output key will be an index range, based on the order of definition of each layer within the model.
- 'nested': If specified, will move recursively through the model definition to retrieve nested layers.
:param nested: (optional) If set, will move recursively through the model definition to retrieve nested layers.
Recursion ends at leaf layers of the model tree or at layers with their name specified in layer_names.
E.g., a model with the following structure
Expand Down Expand Up @@ -158,7 +178,7 @@ def _get_gradients(model, x, y, nodes):
differentiable_nodes.append(n)
except ValueError:
pass
nodes_values = differentiable_nodes
# nodes_values = differentiable_nodes
else:
raise e

Expand Down Expand Up @@ -206,15 +226,15 @@ def output(u):
return node_dict

elif bool(layer_names) and module_name in layer_names:
print("1", module_name, module)
# print("1", module_name, module)
return OrderedDict({module_name: module.output})

elif not bool(layer_names):
print("2", module_name, module)
# print("2", module_name, module)
return OrderedDict({module_name: module.output})

else:
print("3", module_name, module)
# print("3", module_name, module)
return OrderedDict()


Expand Down Expand Up @@ -409,7 +429,6 @@ def display_heatmaps(activations, input_image, directory='.', save=False, fix=Tr
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import math

data_format = K.image_data_format()
Expand Down Expand Up @@ -557,7 +576,6 @@ def load_activations_from_json_file(filename):
:param filename: filename to read the activations from (JSON format)
:return: activations (dict mapping layers)
"""
import numpy as np
with open(filename, 'r') as r:
d = json.load(r, object_pairs_hook=OrderedDict)
activations = OrderedDict({k: np.array(v) for k, v in d.items()})
Expand Down

0 comments on commit e578241

Please sign in to comment.