From a40de10fa4f2923f53003627d0312337e081ec69 Mon Sep 17 00:00:00 2001 From: Paul Haase Date: Tue, 24 Jan 2023 15:06:15 +0100 Subject: [PATCH] Aligned requirements.txt with requirements_cu11.txt and updated handling of tensorflow state dictionaries such that it compatible with the current version of the "h5py" package. --- README.md | 2 +- framework/tensorflow_model/__init__.py | 20 ++++++++++---------- requirements.txt | 10 +++++----- setup.py | 2 +- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 16c7214..d4f5127 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ source env/bin/activate **Note**: For further information on how to set up a virtual python environment (also on **Windows**) refer to https://docs.python.org/3/library/venv.html . -When successfully installed, the software outputs the line : "Successfully installed NNC-0.1.6" +When successfully installed, the software outputs the line : "Successfully installed NNC-0.1.7" ### Importing the main module diff --git a/framework/tensorflow_model/__init__.py b/framework/tensorflow_model/__init__.py index b9cabd8..04cec06 100644 --- a/framework/tensorflow_model/__init__.py +++ b/framework/tensorflow_model/__init__.py @@ -62,9 +62,9 @@ def save_to_tensorflow_file( model_data, path ): grp_names = [] for module_name in model_data: splits = module_name.split('/') - grp_name = module_name.split('/')[0].encode('utf8') + grp_name = module_name.split('/')[0] if splits[0] == splits[2]: - grp_name = (splits[0] + '/' + splits[1]).encode('utf8') + grp_name = (splits[0] + '/' + splits[1]) if grp_name not in grp_names: grp_names.append(grp_name) if model_data[module_name].size != 1: @@ -75,14 +75,14 @@ def save_to_tensorflow_file( model_data, path ): for grp in h5_model: weight_attr = [] - if isinstance(h5_model[grp], h5py.Group) and grp.encode('utf8') in grp_names: - weight_attr = [k[len(grp)+1:].encode('utf8') for k, v in model_data.items() + if isinstance(h5_model[grp], h5py.Group) and grp in grp_names: + weight_attr = [k[len(grp)+1:] for k, v in model_data.items() if k.startswith(grp+'/')] h5_model[grp].attrs['weight_names'] = weight_attr elif isinstance(h5_model[grp], h5py.Group): for subgrp in h5_model[grp]: - if isinstance(h5_model[grp], h5py.Group) and (grp + '/' + subgrp).encode('utf8') in grp_names: - weight_attr = [k[len(grp) + len(subgrp) + 2:].encode('utf8') for k, v in model_data.items() + if isinstance(h5_model[grp], h5py.Group) and (grp + '/' + subgrp) in grp_names: + weight_attr = [k[len(grp) + len(subgrp) + 2:] for k, v in model_data.items() if k.startswith(grp + '/' + subgrp + '/')] h5_model[grp + '/' + subgrp].attrs['weight_names'] = weight_attr @@ -243,13 +243,13 @@ def load_model( self, return self.init_model_from_model_object(model_file) else: if 'layer_names' in model_file.attrs: - module_names = [n.decode('utf8') for n in model_file.attrs['layer_names']] + module_names = [n for n in model_file.attrs['layer_names']] layer_names = [] for mod_name in module_names: layer = model_file[mod_name] if 'weight_names' in layer.attrs: - weight_names = [mod_name+'/'+n.decode('utf8') for n in layer.attrs['weight_names']] + weight_names = [mod_name+'/'+n for n in layer.attrs['weight_names']] if weight_names: layer_names += weight_names @@ -273,13 +273,13 @@ def init_model_from_model_object( self, os.remove(h5_model_path) if 'layer_names' in model.attrs: - module_names = [n.decode('utf8') for n in model.attrs['layer_names']] + module_names = [n for n in model.attrs['layer_names']] layer_names = [] for mod_name in module_names: layer = model[mod_name] if 'weight_names' in layer.attrs: - weight_names = [mod_name+'/'+n.decode('utf8') for n in layer.attrs['weight_names']] + weight_names = [mod_name+'/'+n for n in layer.attrs['weight_names']] if weight_names: layer_names += weight_names diff --git a/requirements.txt b/requirements.txt index 48169ca..fef2a27 100755 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,10 @@ Click>=7.0 scikit-learn>=0.23.1 tqdm>=4.32.2 -h5py>=2.10.0 -pybind11>=2.5.0 -torch>=1.7.1 -torchvision>=0.8.2 -tensorflow>=2.3.1 +h5py>=3.1.0 +pybind11>=2.6.2 +torch>=1.8.1 +torchvision>=0.9.1 +tensorflow>=2.6.0 pandas>=1.0.5 opencv-python>=4.4.0.46 diff --git a/setup.py b/setup.py index 162ee4e..7018b09 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ from setuptools.command.build_ext import build_ext import setuptools -__version__ = '0.1.6' +__version__ = '0.1.7' class get_pybind_include(object):