From 13a2fce4bb8df832cc890bb9c50a4703c4c3f2fb Mon Sep 17 00:00:00 2001 From: Paul Haase Date: Tue, 17 Jan 2023 13:31:50 +0100 Subject: [PATCH] Bugfix for compression of arbitrary numpy dictionaries with 'compress': - The encoder crashed during encoding of arbitrary numpy dictionaries with the 'compress' function. This was caused by missing entries for the parameter types in the model_information dictionary. --- README.md | 2 +- nnc_core/nnr_model/__init__.py | 10 +++++----- setup.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index efe216a..16c7214 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ source env/bin/activate **Note**: For further information on how to set up a virtual python environment (also on **Windows**) refer to https://docs.python.org/3/library/venv.html . -When successfully installed, the software outputs the line : "Successfully installed NNC-0.1.5" +When successfully installed, the software outputs the line : "Successfully installed NNC-0.1.6" ### Importing the main module diff --git a/nnc_core/nnr_model/__init__.py b/nnc_core/nnr_model/__init__.py index b50bafe..e90a669 100644 --- a/nnc_core/nnr_model/__init__.py +++ b/nnc_core/nnr_model/__init__.py @@ -165,12 +165,12 @@ def init_model_from_dict(self, model_dict): model_info['parameter_dimensions'][module_name] = np.array([0]).shape model_info['parameter_index'][module_name] = i - dims = len(mdl_shape) + dims = len(mdl_shape) - if dims > 1: - model_info['parameter_type'][module_name] = 'weight' - else: - model_info['parameter_type'][module_name] = 'unspecified' + if dims > 1: + model_info['parameter_type'][module_name] = 'weight' + else: + model_info['parameter_type'][module_name] = 'unspecified' model_info['topology_storage_format'] = nnc_core.nnr_model.TopologyStorageFormat.NNR_TPL_UNREC model_info['topology_compression_format'] = nnc_core.nnr_model.TopologyCompressionFormat.NNR_PT_RAW diff --git a/setup.py b/setup.py index 4f7bf06..162ee4e 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ from setuptools.command.build_ext import build_ext import setuptools -__version__ = '0.1.5' +__version__ = '0.1.6' class get_pybind_include(object):