diff --git a/delta/config/config.py b/delta/config/config.py index aba3b0ae..bbb2127f 100644 --- a/delta/config/config.py +++ b/delta/config/config.py @@ -125,14 +125,22 @@ def register_arg(self, field, argname, **kwargs): kwargs['default'] = _NotSpecified self._cmd_args[argname] = (field, kwargs) + def to_dict(self) -> dict: + """ + Returns a dictionary representing the config object. + """ + if isinstance(self._config_dict, dict): + exp = self._config_dict.copy() + for (name, c) in self._components.items(): + exp[name] = c.to_dict() + return exp + return self._config_dict + def export(self) -> str: """ - Returns a YAML string of all configuration options. + Returns a YAML string of all configuration options, from to_dict. """ - exp = self._config_dict.copy() - for (name, c) in self._components.items(): - exp[name] = c.export() - return yaml.dump(exp) + return yaml.dump(self.to_dict()) def _set_field(self, name : str, value : str, base_dir : str): if name not in self._fields: diff --git a/delta/imagery/imagery_config.py b/delta/imagery/imagery_config.py index 182fc2f0..dbf7a85d 100644 --- a/delta/imagery/imagery_config.py +++ b/delta/imagery/imagery_config.py @@ -252,6 +252,7 @@ def __len__(self): def _load_dict(self, d : dict, base_dir): if not d: return + self._config_dict = d self._classes = [] if isinstance(d, int): for i in range(d): diff --git a/delta/ml/io.py b/delta/ml/io.py new file mode 100644 index 00000000..12e4733f --- /dev/null +++ b/delta/ml/io.py @@ -0,0 +1,32 @@ +# Copyright © 2020, United States Government, as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All rights reserved. +# +# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is +# licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functions for IO specific to ML. +""" + +import h5py + +from delta.config import config + +def save_model(model, filename): + """ + Save a model. Includes DELTA configuration. + """ + model.save(filename, save_format='h5') + with h5py.File(filename, 'r+') as f: + f.attrs['delta'] = config.export() diff --git a/delta/ml/ml_config.py b/delta/ml/ml_config.py index 9eed4efb..10ea3bcf 100644 --- a/delta/ml/ml_config.py +++ b/delta/ml/ml_config.py @@ -106,25 +106,17 @@ def _load_dict(self, d : dict, base_dir): self._config_dict['layers'] = None elif 'layers' in d: self._config_dict['yaml_file'] = None - - def as_dict(self) -> dict: - """ - Returns a dictionary representing the network model for use by `delta.ml.model_parser`. - """ - yaml_file = self._config_dict['yaml_file'] - if yaml_file is not None: - if self._config_dict['layers'] is not None: - raise ValueError('Specified both yaml file and layers in model.') - + if 'yaml_file' in d and 'layers' in d and d['yaml_file'] is not None and d['layers'] is not None: + raise ValueError('Specified both yaml file and layers in model.') + if 'yaml_file' in d and d['yaml_file'] is not None: + yaml_file = d['yaml_file'] resource = os.path.join('config', yaml_file) if not os.path.exists(yaml_file) and pkg_resources.resource_exists('delta', resource): yaml_file = pkg_resources.resource_filename('delta', resource) if not os.path.exists(yaml_file): raise ValueError('Model yaml_file does not exist: ' + yaml_file) - #print('Opening model file: ' + yaml_file) with open(yaml_file, 'r') as f: - return yaml.safe_load(f) - return self._config_dict + self._config_dict.update(yaml.safe_load(f)) class NetworkConfig(config.DeltaConfigComponent): def __init__(self): @@ -166,7 +158,7 @@ def images(self) -> ImageSet: if self.__images is None: (self.__images, self.__labels) = load_images_labels(self._components['images'], self._components['labels'], - config.dataset.classes) + config.config.dataset.classes) return self.__images def labels(self) -> ImageSet: @@ -176,7 +168,7 @@ def labels(self) -> ImageSet: if self.__labels is None: (self.__images, self.__labels) = load_images_labels(self._components['images'], self._components['labels'], - config.dataset.classes) + config.config.dataset.classes) return self.__labels class TrainingConfig(config.DeltaConfigComponent): diff --git a/delta/ml/model_parser.py b/delta/ml/model_parser.py index ff9c15d6..3ba37ff3 100644 --- a/delta/ml/model_parser.py +++ b/delta/ml/model_parser.py @@ -154,4 +154,4 @@ def config_model(num_bands: int) -> Callable[[], tensorflow.keras.models.Sequent 'in_dims' : in_data_shape[0] * in_data_shape[1] * in_data_shape[2], 'num_bands' : in_data_shape[2]} - return model_from_dict(config.train.network.model.as_dict(), params_exposed) + return model_from_dict(config.train.network.model.to_dict(), params_exposed) diff --git a/delta/ml/train.py b/delta/ml/train.py index 6d31f96d..38dbf56a 100644 --- a/delta/ml/train.py +++ b/delta/ml/train.py @@ -31,6 +31,7 @@ from delta.imagery.imagery_dataset import ImageryDataset from delta.imagery.imagery_dataset import AutoencoderDataset from .layers import DeltaLayer +from .io import save_model def _devices(num_gpus): ''' @@ -132,7 +133,7 @@ def on_train_batch_end(self, batch, logs=None): mlflow.log_metric(k, logs[k], step=batch) if config.mlflow.checkpoints.frequency() and batch % config.mlflow.checkpoints.frequency() == 0: filename = os.path.join(self.temp_dir, '%d.h5' % (batch)) - self.model.save(filename, save_format='h5') + save_model(self.model, filename) if config.mlflow.checkpoints.only_save_latest(): old = filename filename = os.path.join(self.temp_dir, 'latest.h5') @@ -228,7 +229,7 @@ def train(model_fn, dataset : ImageryDataset, training_spec): if config.mlflow.enabled(): model_path = os.path.join(mcb.temp_dir, 'final_model.h5') print('\nFinished, saving model to %s.' % (mlflow.get_artifact_uri() + '/final_model.h5')) - model.save(model_path, save_format='h5') + save_model(model, model_path) mlflow.log_artifact(model_path) os.remove(model_path) mlflow.log_param('Status', 'Completed') @@ -238,7 +239,7 @@ def train(model_fn, dataset : ImageryDataset, training_spec): mlflow.end_run('FAILED') model_path = os.path.join(mcb.temp_dir, 'aborted_model.h5') print('\nAborting, saving current model to %s.' % (mlflow.get_artifact_uri() + '/aborted_model.h5')) - model.save(model_path, save_format='h5') + save_model(model, model_path) mlflow.log_artifact(model_path) os.remove(model_path) raise diff --git a/delta/subcommands/train.py b/delta/subcommands/train.py index e94a7491..85b0ee31 100644 --- a/delta/subcommands/train.py +++ b/delta/subcommands/train.py @@ -33,6 +33,7 @@ from delta.ml.train import train from delta.ml.model_parser import config_model from delta.ml.layers import ALL_LAYERS +from delta.ml.io import save_model #tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG) @@ -74,7 +75,7 @@ def main(options): model, _ = train(model, ids, tc) if options.model is not None: - model.save(options.model) + save_model(model, options.model) except KeyboardInterrupt: print() print('Training cancelled.') diff --git a/scripts/model2config b/scripts/model2config old mode 100644 new mode 100755 index 51c4162b..3b74504c --- a/scripts/model2config +++ b/scripts/model2config @@ -2,6 +2,7 @@ import tensorflow as tf from argparse import ArgumentParser +import h5py import pathlib parser = ArgumentParser(description='Converts a neural network in a *.h5 file to the DELTA configuration langauge') @@ -9,8 +10,15 @@ parser.add_argument('model_name', type=pathlib.Path, help='The model to convert' args = parser.parse_args() -a = tf.keras.models.load_model(args.model_name) +print('Configuration File') +with h5py.File(args.model_name, 'r') as f: + if 'delta' not in f.attrs: + print(' - Not Available\n') + else: + print('\n' + f.attrs['delta'] + '\n') +a = tf.keras.models.load_model(args.model_name) +print('Network Structure') for l in a.layers: print('\t- ', type(l).__name__) configs = l.get_config() diff --git a/setup.py b/setup.py index 1a9a62ca..36ba9160 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,8 @@ 'mlflow', 'portalocker', 'appdirs', - 'gdal' + 'gdal', + 'h5py' ], scripts=scripts, include_package_data = True,