diff --git a/tests/test_model.py b/tests/test_model.py index c7b885d..4a61c08 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -28,10 +28,11 @@ def torch2tf(torch_name): fields = torch_name.split('.') - element_id = int(fields[2]) + offset = int(fields[2] == "networks") + element_id = int(fields[2 + offset]) if fields[0] == 'descriptor': - layer_id = int(fields[4]) + 1 - weight_type = fields[5] + layer_id = int(fields[4 + offset]) + 1 + weight_type = fields[5 + offset] return 'filter_type_all/%s_%d_%d:0' % (weight_type, layer_id, element_id) elif fields[3] == 'deep_layers': layer_id = int(fields[4])