Skip to content

Commit

Permalink
Add config file to h5 files saved by delta. (#20)
Browse files Browse the repository at this point in the history
Also fix bug causing crash with validation.
  • Loading branch information
bcoltin authored Aug 13, 2020
1 parent ffce6d0 commit ad01135
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 27 deletions.
18 changes: 13 additions & 5 deletions delta/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,22 @@ def register_arg(self, field, argname, **kwargs):
kwargs['default'] = _NotSpecified
self._cmd_args[argname] = (field, kwargs)

def to_dict(self) -> dict:
"""
Returns a dictionary representing the config object.
"""
if isinstance(self._config_dict, dict):
exp = self._config_dict.copy()
for (name, c) in self._components.items():
exp[name] = c.to_dict()
return exp
return self._config_dict

def export(self) -> str:
"""
Returns a YAML string of all configuration options.
Returns a YAML string of all configuration options, from to_dict.
"""
exp = self._config_dict.copy()
for (name, c) in self._components.items():
exp[name] = c.export()
return yaml.dump(exp)
return yaml.dump(self.to_dict())

def _set_field(self, name : str, value : str, base_dir : str):
if name not in self._fields:
Expand Down
1 change: 1 addition & 0 deletions delta/imagery/imagery_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def __len__(self):
def _load_dict(self, d : dict, base_dir):
if not d:
return
self._config_dict = d
self._classes = []
if isinstance(d, int):
for i in range(d):
Expand Down
32 changes: 32 additions & 0 deletions delta/ml/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright © 2020, United States Government, as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All rights reserved.
#
# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
# licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Functions for IO specific to ML.
"""

import h5py

from delta.config import config

def save_model(model, filename):
"""
Save a model. Includes DELTA configuration.
"""
model.save(filename, save_format='h5')
with h5py.File(filename, 'r+') as f:
f.attrs['delta'] = config.export()
22 changes: 7 additions & 15 deletions delta/ml/ml_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,25 +106,17 @@ def _load_dict(self, d : dict, base_dir):
self._config_dict['layers'] = None
elif 'layers' in d:
self._config_dict['yaml_file'] = None

def as_dict(self) -> dict:
"""
Returns a dictionary representing the network model for use by `delta.ml.model_parser`.
"""
yaml_file = self._config_dict['yaml_file']
if yaml_file is not None:
if self._config_dict['layers'] is not None:
raise ValueError('Specified both yaml file and layers in model.')

if 'yaml_file' in d and 'layers' in d and d['yaml_file'] is not None and d['layers'] is not None:
raise ValueError('Specified both yaml file and layers in model.')
if 'yaml_file' in d and d['yaml_file'] is not None:
yaml_file = d['yaml_file']
resource = os.path.join('config', yaml_file)
if not os.path.exists(yaml_file) and pkg_resources.resource_exists('delta', resource):
yaml_file = pkg_resources.resource_filename('delta', resource)
if not os.path.exists(yaml_file):
raise ValueError('Model yaml_file does not exist: ' + yaml_file)
#print('Opening model file: ' + yaml_file)
with open(yaml_file, 'r') as f:
return yaml.safe_load(f)
return self._config_dict
self._config_dict.update(yaml.safe_load(f))

class NetworkConfig(config.DeltaConfigComponent):
def __init__(self):
Expand Down Expand Up @@ -166,7 +158,7 @@ def images(self) -> ImageSet:
if self.__images is None:
(self.__images, self.__labels) = load_images_labels(self._components['images'],
self._components['labels'],
config.dataset.classes)
config.config.dataset.classes)
return self.__images

def labels(self) -> ImageSet:
Expand All @@ -176,7 +168,7 @@ def labels(self) -> ImageSet:
if self.__labels is None:
(self.__images, self.__labels) = load_images_labels(self._components['images'],
self._components['labels'],
config.dataset.classes)
config.config.dataset.classes)
return self.__labels

class TrainingConfig(config.DeltaConfigComponent):
Expand Down
2 changes: 1 addition & 1 deletion delta/ml/model_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,4 @@ def config_model(num_bands: int) -> Callable[[], tensorflow.keras.models.Sequent
'in_dims' : in_data_shape[0] * in_data_shape[1] * in_data_shape[2],
'num_bands' : in_data_shape[2]}

return model_from_dict(config.train.network.model.as_dict(), params_exposed)
return model_from_dict(config.train.network.model.to_dict(), params_exposed)
7 changes: 4 additions & 3 deletions delta/ml/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from delta.imagery.imagery_dataset import ImageryDataset
from delta.imagery.imagery_dataset import AutoencoderDataset
from .layers import DeltaLayer
from .io import save_model

def _devices(num_gpus):
'''
Expand Down Expand Up @@ -132,7 +133,7 @@ def on_train_batch_end(self, batch, logs=None):
mlflow.log_metric(k, logs[k], step=batch)
if config.mlflow.checkpoints.frequency() and batch % config.mlflow.checkpoints.frequency() == 0:
filename = os.path.join(self.temp_dir, '%d.h5' % (batch))
self.model.save(filename, save_format='h5')
save_model(self.model, filename)
if config.mlflow.checkpoints.only_save_latest():
old = filename
filename = os.path.join(self.temp_dir, 'latest.h5')
Expand Down Expand Up @@ -228,7 +229,7 @@ def train(model_fn, dataset : ImageryDataset, training_spec):
if config.mlflow.enabled():
model_path = os.path.join(mcb.temp_dir, 'final_model.h5')
print('\nFinished, saving model to %s.' % (mlflow.get_artifact_uri() + '/final_model.h5'))
model.save(model_path, save_format='h5')
save_model(model, model_path)
mlflow.log_artifact(model_path)
os.remove(model_path)
mlflow.log_param('Status', 'Completed')
Expand All @@ -238,7 +239,7 @@ def train(model_fn, dataset : ImageryDataset, training_spec):
mlflow.end_run('FAILED')
model_path = os.path.join(mcb.temp_dir, 'aborted_model.h5')
print('\nAborting, saving current model to %s.' % (mlflow.get_artifact_uri() + '/aborted_model.h5'))
model.save(model_path, save_format='h5')
save_model(model, model_path)
mlflow.log_artifact(model_path)
os.remove(model_path)
raise
Expand Down
3 changes: 2 additions & 1 deletion delta/subcommands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from delta.ml.train import train
from delta.ml.model_parser import config_model
from delta.ml.layers import ALL_LAYERS
from delta.ml.io import save_model

#tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)

Expand Down Expand Up @@ -74,7 +75,7 @@ def main(options):
model, _ = train(model, ids, tc)

if options.model is not None:
model.save(options.model)
save_model(model, options.model)
except KeyboardInterrupt:
print()
print('Training cancelled.')
Expand Down
10 changes: 9 additions & 1 deletion scripts/model2config
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@

import tensorflow as tf
from argparse import ArgumentParser
import h5py
import pathlib

parser = ArgumentParser(description='Converts a neural network in a *.h5 file to the DELTA configuration langauge')
parser.add_argument('model_name', type=pathlib.Path, help='The model to convert')

args = parser.parse_args()

a = tf.keras.models.load_model(args.model_name)
print('Configuration File')
with h5py.File(args.model_name, 'r') as f:
if 'delta' not in f.attrs:
print(' - Not Available\n')
else:
print('\n' + f.attrs['delta'] + '\n')

a = tf.keras.models.load_model(args.model_name)
print('Network Structure')
for l in a.layers:
print('\t- ', type(l).__name__)
configs = l.get_config()
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
'mlflow',
'portalocker',
'appdirs',
'gdal'
'gdal',
'h5py'
],
scripts=scripts,
include_package_data = True,
Expand Down

0 comments on commit ad01135

Please sign in to comment.