Skip to content

Commit

Permalink
Aligned requirements.txt with requirements_cu11.txt and updated handl…
Browse files Browse the repository at this point in the history
…ing of tensorflow state dictionaries such that it compatible with the current version of the "h5py" package.
  • Loading branch information
phaase-hhi committed Jan 24, 2023
1 parent 13a2fce commit a40de10
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ source env/bin/activate

**Note**: For further information on how to set up a virtual python environment (also on **Windows**) refer to https://docs.python.org/3/library/venv.html .

When successfully installed, the software outputs the line : "Successfully installed NNC-0.1.6"
When successfully installed, the software outputs the line : "Successfully installed NNC-0.1.7"

### Importing the main module

Expand Down
20 changes: 10 additions & 10 deletions framework/tensorflow_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def save_to_tensorflow_file( model_data, path ):
grp_names = []
for module_name in model_data:
splits = module_name.split('/')
grp_name = module_name.split('/')[0].encode('utf8')
grp_name = module_name.split('/')[0]
if splits[0] == splits[2]:
grp_name = (splits[0] + '/' + splits[1]).encode('utf8')
grp_name = (splits[0] + '/' + splits[1])
if grp_name not in grp_names:
grp_names.append(grp_name)
if model_data[module_name].size != 1:
Expand All @@ -75,14 +75,14 @@ def save_to_tensorflow_file( model_data, path ):

for grp in h5_model:
weight_attr = []
if isinstance(h5_model[grp], h5py.Group) and grp.encode('utf8') in grp_names:
weight_attr = [k[len(grp)+1:].encode('utf8') for k, v in model_data.items()
if isinstance(h5_model[grp], h5py.Group) and grp in grp_names:
weight_attr = [k[len(grp)+1:] for k, v in model_data.items()
if k.startswith(grp+'/')]
h5_model[grp].attrs['weight_names'] = weight_attr
elif isinstance(h5_model[grp], h5py.Group):
for subgrp in h5_model[grp]:
if isinstance(h5_model[grp], h5py.Group) and (grp + '/' + subgrp).encode('utf8') in grp_names:
weight_attr = [k[len(grp) + len(subgrp) + 2:].encode('utf8') for k, v in model_data.items()
if isinstance(h5_model[grp], h5py.Group) and (grp + '/' + subgrp) in grp_names:
weight_attr = [k[len(grp) + len(subgrp) + 2:] for k, v in model_data.items()
if k.startswith(grp + '/' + subgrp + '/')]
h5_model[grp + '/' + subgrp].attrs['weight_names'] = weight_attr

Expand Down Expand Up @@ -243,13 +243,13 @@ def load_model( self,
return self.init_model_from_model_object(model_file)
else:
if 'layer_names' in model_file.attrs:
module_names = [n.decode('utf8') for n in model_file.attrs['layer_names']]
module_names = [n for n in model_file.attrs['layer_names']]

layer_names = []
for mod_name in module_names:
layer = model_file[mod_name]
if 'weight_names' in layer.attrs:
weight_names = [mod_name+'/'+n.decode('utf8') for n in layer.attrs['weight_names']]
weight_names = [mod_name+'/'+n for n in layer.attrs['weight_names']]
if weight_names:
layer_names += weight_names

Expand All @@ -273,13 +273,13 @@ def init_model_from_model_object( self,
os.remove(h5_model_path)

if 'layer_names' in model.attrs:
module_names = [n.decode('utf8') for n in model.attrs['layer_names']]
module_names = [n for n in model.attrs['layer_names']]

layer_names = []
for mod_name in module_names:
layer = model[mod_name]
if 'weight_names' in layer.attrs:
weight_names = [mod_name+'/'+n.decode('utf8') for n in layer.attrs['weight_names']]
weight_names = [mod_name+'/'+n for n in layer.attrs['weight_names']]
if weight_names:
layer_names += weight_names

Expand Down
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
Click>=7.0
scikit-learn>=0.23.1
tqdm>=4.32.2
h5py>=2.10.0
pybind11>=2.5.0
torch>=1.7.1
torchvision>=0.8.2
tensorflow>=2.3.1
h5py>=3.1.0
pybind11>=2.6.2
torch>=1.8.1
torchvision>=0.9.1
tensorflow>=2.6.0
pandas>=1.0.5
opencv-python>=4.4.0.46
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from setuptools.command.build_ext import build_ext
import setuptools

__version__ = '0.1.6'
__version__ = '0.1.7'


class get_pybind_include(object):
Expand Down

0 comments on commit a40de10

Please sign in to comment.