Skip to content

Commit

Permalink
[Keras] Add l2_normalize support (#9383)
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov authored Nov 3, 2021
1 parent 8ada2b1 commit bff9884
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ perf
.bash_history
*.json
*.params
*.ro
*.onnx
*.h5
synset.txt
Expand Down Expand Up @@ -240,4 +241,4 @@ conda/pkg
# Downloaded models/datasets
.tvm_test_data
.dgl
.caffe2
.caffe2
104 changes: 103 additions & 1 deletion python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name, import-self, import-outside-toplevel
"""Keras frontend."""
import dis
import sys
import numpy as np
import tvm
Expand Down Expand Up @@ -988,10 +989,110 @@ def _convert_repeat_vector(inexpr, keras_layer, _):
out_shape = [-1, repeats] + input_shape[1:]
out = _op.repeat(inexpr, repeats=repeats, axis=0)
out = _op.reshape(out, out_shape)

return out


def _convert_l2_normalize(inexpr, keras_layer, etab):
l2_normalize_is_loaded = False
param_list = []
for i in dis.get_instructions(keras_layer.function):
if i.opname in ["LOAD_GLOBAL", "LOAD_DEREF"]:
continue
if i.opname in ["LOAD_ATTR", "LOAD_METHOD"]:
if i.argval == "l2_normalize":
assert not l2_normalize_is_loaded, "l2_normalize was already LOADED"
l2_normalize_is_loaded = True
elif i.opname in ["LOAD_CONST", "LOAD_FAST"] and l2_normalize_is_loaded:
param_list.append(i.argval)
elif i.opname == "BUILD_LIST":
sz = i.argval
assert len(param_list) >= sz
new_list = param_list[-sz:]
param_list = param_list[:-sz]
param_list.append(new_list)
elif i.opname in ["CALL_FUNCTION_KW", "CALL_METHOD"]:
break

axis = None
is_param_list_parsed = False
if l2_normalize_is_loaded and len(param_list) > 0:
# last param_list item is tuple of strings means that
# lambda uses named parameters when calling l2_normalize
if (
isinstance(param_list[-1], tuple)
and len(param_list[-1]) > 0
and isinstance(param_list[-1][0], str)
):
param_names = param_list[-1]
if len(param_names) == 1 and param_names[0] == "x":
# lambda v: K.l2_normalize(x=v)
axis = None
is_param_list_parsed = True
elif len(param_names) == 1 and param_names[0] == "axis" and len(param_list) == 3:
# lambda x: K.l2_normalize(x, axis=(2,3))
axis = param_list[1]
is_param_list_parsed = True
elif len(param_names) == 2 and len(param_list) == 3:
# lambda x: K.l2_normalize(x=x, axis=(2,3))
# lambda x: K.l2_normalize(axis=(2,3), x=x)
axis = param_list[param_names.index("axis")]
is_param_list_parsed = True
else:
# lambda x: K.l2_normalize(x)
if len(param_list) == 1:
axis = None
is_param_list_parsed = True
# lambda x: K.l2_normalize(x, (2,3))
elif len(param_list) == 2:
axis = param_list[1]
is_param_list_parsed = True

def is_int_or_tuple_of_ints(v):
if isinstance(v, list) and len(v) > 0:
for i in v:
if not isinstance(i, int):
return False
return True
if isinstance(v, tuple) and len(v) > 0:
return isinstance(v[0], int)
return isinstance(v, int)

assert is_param_list_parsed and (
axis is None or is_int_or_tuple_of_ints(axis)
), "Can not parse l2_normalize lambda function found in Lambda layer"
if isinstance(axis, int):
axis = [axis]

if etab.data_layout == "NCHW":
dims = len(keras_layer.input_shape)

def fix_axis_for_nchw(axis):
if axis == 0:
return 0
if axis in [(dims - 1), -1]:
return 1
return axis + 1

axis = [fix_axis_for_nchw(x) for x in axis]
return _op.nn.l2_normalize(inexpr, eps=1e-12, axis=axis)


def _convert_lambda(inexpr, keras_layer, etab):
fcode = keras_layer.function.__code__
# Convert l2_normalize
if (
fcode.co_name == "<lambda>"
and len(fcode.co_names) > 0
and fcode.co_names[-1] == "l2_normalize"
):
return _convert_l2_normalize(inexpr, keras_layer, etab)
raise tvm.error.OpNotImplemented(
"Function {} used in Lambda layer is not supported in frontend Keras.".format(
fcode.co_names
)
)


def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
"""Layers that can be skipped because they are train time only."""
return inexpr
Expand Down Expand Up @@ -1056,6 +1157,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
"Permute": _convert_permute,
"Embedding": _convert_embedding,
"RepeatVector": _convert_repeat_vector,
"Lambda": _convert_lambda,
"InputLayer": _default_skip,
"Dropout": _default_skip,
"AlphaDropout": _default_skip,
Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,26 @@ def test_forward_nested_layers(self, keras):
)
verify_keras_frontend(keras_model)

def test_forward_l2_normalize(self, keras):
data = keras.layers.Input(shape=(16, 12, 8))
K = keras.backend
l2_funcs = [
keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=-2)),
keras.layers.Lambda(lambda v: K.l2_normalize(x=v, axis=-1)),
keras.layers.Lambda(lambda v: K.l2_normalize(axis=1, x=v)),
keras.layers.Lambda(lambda v: K.l2_normalize(v, 2)),
keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=3)),
keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=(2, 3))),
keras.layers.Lambda(lambda v: K.l2_normalize(v, (1, 2))),
keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=[-2, -1])),
keras.layers.Lambda(lambda v: K.l2_normalize(v, [-3, -2])),
]
for l2_func in l2_funcs:
x = l2_func(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, layout="NCHW")
verify_keras_frontend(keras_model, layout="NHWC")


if __name__ == "__main__":
for k in [keras, tf_keras]:
Expand Down Expand Up @@ -641,3 +661,4 @@ def test_forward_nested_layers(self, keras):
sut.test_forward_zero_padding3d(keras=k)
sut.test_forward_embedding(keras=k)
sut.test_forward_repeat_vector(keras=k)
sut.test_forward_l2_normalize(keras=k)

0 comments on commit bff9884

Please sign in to comment.