Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add config file to h5 files saved by delta. #20

Merged
merged 2 commits into from
Aug 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions delta/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,22 @@ def register_arg(self, field, argname, **kwargs):
kwargs['default'] = _NotSpecified
self._cmd_args[argname] = (field, kwargs)

def to_dict(self) -> dict:
    """
    Returns a dictionary representing the config object.

    When the stored configuration (`self._config_dict`) is a dict, the
    result is a copy of it in which every name in `self._components` is
    mapped to that component's own `to_dict()` output. When the stored
    configuration is not a dict, a copy of the stored value is returned
    unchanged and the components are not consulted.

    The returned value is a deep copy, so callers may modify it (including
    nested lists and dicts) without altering the configuration held by
    this object.
    """
    import copy  # local import: the module's import block is not visible here

    # Deep copy rather than dict.copy(): a shallow copy would still share
    # nested containers with the internal state.
    exp = copy.deepcopy(self._config_dict)
    if isinstance(exp, dict):
        for (name, c) in self._components.items():
            exp[name] = c.to_dict()
    return exp

def export(self) -> str:
"""
Returns a YAML string of all configuration options.
Returns a YAML string of all configuration options, from to_dict.
"""
exp = self._config_dict.copy()
for (name, c) in self._components.items():
exp[name] = c.export()
return yaml.dump(exp)
return yaml.dump(self.to_dict())

def _set_field(self, name : str, value : str, base_dir : str):
if name not in self._fields:
Expand Down
1 change: 1 addition & 0 deletions delta/imagery/imagery_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def __len__(self):
def _load_dict(self, d : dict, base_dir):
if not d:
return
self._config_dict = d
self._classes = []
if isinstance(d, int):
for i in range(d):
Expand Down
32 changes: 32 additions & 0 deletions delta/ml/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright © 2020, United States Government, as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All rights reserved.
#
# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
# licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Functions for IO specific to ML.
"""

import h5py

from delta.config import config

def save_model(model, filename):
    """
    Write `model` to `filename` as an HDF5 file and embed the DELTA configuration.

    The model is first saved through `model.save` with `save_format='h5'`.
    The resulting file is then reopened with h5py in 'r+' mode and the value
    returned by `config.export()` is stored in the file's 'delta' attribute.
    """
    model.save(filename, save_format='h5')
    # Reopen the file just written so the configuration travels with the model.
    with h5py.File(filename, 'r+') as h5_file:
        exported_config = config.export()
        h5_file.attrs['delta'] = exported_config
22 changes: 7 additions & 15 deletions delta/ml/ml_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,25 +106,17 @@ def _load_dict(self, d : dict, base_dir):
self._config_dict['layers'] = None
elif 'layers' in d:
self._config_dict['yaml_file'] = None

def as_dict(self) -> dict:
"""
Returns a dictionary representing the network model for use by `delta.ml.model_parser`.
"""
yaml_file = self._config_dict['yaml_file']
if yaml_file is not None:
if self._config_dict['layers'] is not None:
raise ValueError('Specified both yaml file and layers in model.')

if 'yaml_file' in d and 'layers' in d and d['yaml_file'] is not None and d['layers'] is not None:
raise ValueError('Specified both yaml file and layers in model.')
if 'yaml_file' in d and d['yaml_file'] is not None:
yaml_file = d['yaml_file']
resource = os.path.join('config', yaml_file)
if not os.path.exists(yaml_file) and pkg_resources.resource_exists('delta', resource):
yaml_file = pkg_resources.resource_filename('delta', resource)
if not os.path.exists(yaml_file):
raise ValueError('Model yaml_file does not exist: ' + yaml_file)
#print('Opening model file: ' + yaml_file)
with open(yaml_file, 'r') as f:
return yaml.safe_load(f)
return self._config_dict
self._config_dict.update(yaml.safe_load(f))

class NetworkConfig(config.DeltaConfigComponent):
def __init__(self):
Expand Down Expand Up @@ -166,7 +158,7 @@ def images(self) -> ImageSet:
if self.__images is None:
(self.__images, self.__labels) = load_images_labels(self._components['images'],
self._components['labels'],
config.dataset.classes)
config.config.dataset.classes)
return self.__images

def labels(self) -> ImageSet:
Expand All @@ -176,7 +168,7 @@ def labels(self) -> ImageSet:
if self.__labels is None:
(self.__images, self.__labels) = load_images_labels(self._components['images'],
self._components['labels'],
config.dataset.classes)
config.config.dataset.classes)
return self.__labels

class TrainingConfig(config.DeltaConfigComponent):
Expand Down
2 changes: 1 addition & 1 deletion delta/ml/model_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,4 @@ def config_model(num_bands: int) -> Callable[[], tensorflow.keras.models.Sequent
'in_dims' : in_data_shape[0] * in_data_shape[1] * in_data_shape[2],
'num_bands' : in_data_shape[2]}

return model_from_dict(config.train.network.model.as_dict(), params_exposed)
return model_from_dict(config.train.network.model.to_dict(), params_exposed)
7 changes: 4 additions & 3 deletions delta/ml/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from delta.imagery.imagery_dataset import ImageryDataset
from delta.imagery.imagery_dataset import AutoencoderDataset
from .layers import DeltaLayer
from .io import save_model

def _devices(num_gpus):
'''
Expand Down Expand Up @@ -132,7 +133,7 @@ def on_train_batch_end(self, batch, logs=None):
mlflow.log_metric(k, logs[k], step=batch)
if config.mlflow.checkpoints.frequency() and batch % config.mlflow.checkpoints.frequency() == 0:
filename = os.path.join(self.temp_dir, '%d.h5' % (batch))
self.model.save(filename, save_format='h5')
save_model(self.model, filename)
if config.mlflow.checkpoints.only_save_latest():
old = filename
filename = os.path.join(self.temp_dir, 'latest.h5')
Expand Down Expand Up @@ -228,7 +229,7 @@ def train(model_fn, dataset : ImageryDataset, training_spec):
if config.mlflow.enabled():
model_path = os.path.join(mcb.temp_dir, 'final_model.h5')
print('\nFinished, saving model to %s.' % (mlflow.get_artifact_uri() + '/final_model.h5'))
model.save(model_path, save_format='h5')
save_model(model, model_path)
mlflow.log_artifact(model_path)
os.remove(model_path)
mlflow.log_param('Status', 'Completed')
Expand All @@ -238,7 +239,7 @@ def train(model_fn, dataset : ImageryDataset, training_spec):
mlflow.end_run('FAILED')
model_path = os.path.join(mcb.temp_dir, 'aborted_model.h5')
print('\nAborting, saving current model to %s.' % (mlflow.get_artifact_uri() + '/aborted_model.h5'))
model.save(model_path, save_format='h5')
save_model(model, model_path)
mlflow.log_artifact(model_path)
os.remove(model_path)
raise
Expand Down
3 changes: 2 additions & 1 deletion delta/subcommands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from delta.ml.train import train
from delta.ml.model_parser import config_model
from delta.ml.layers import ALL_LAYERS
from delta.ml.io import save_model

#tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)

Expand Down Expand Up @@ -74,7 +75,7 @@ def main(options):
model, _ = train(model, ids, tc)

if options.model is not None:
model.save(options.model)
save_model(model, options.model)
except KeyboardInterrupt:
print()
print('Training cancelled.')
Expand Down
10 changes: 9 additions & 1 deletion scripts/model2config
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@

import tensorflow as tf
from argparse import ArgumentParser
import h5py
import pathlib

# Command-line interface: one positional argument naming the *.h5 model file.
parser = ArgumentParser(description='Converts a neural network in a *.h5 file to the DELTA configuration language')
parser.add_argument('model_name', type=pathlib.Path, help='The model to convert')

args = parser.parse_args()

a = tf.keras.models.load_model(args.model_name)
print('Configuration File')
with h5py.File(args.model_name, 'r') as f:
if 'delta' not in f.attrs:
print(' - Not Available\n')
else:
print('\n' + f.attrs['delta'] + '\n')

a = tf.keras.models.load_model(args.model_name)
print('Network Structure')
for l in a.layers:
print('\t- ', type(l).__name__)
configs = l.get_config()
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
'mlflow',
'portalocker',
'appdirs',
'gdal'
'gdal',
'h5py'
],
scripts=scripts,
include_package_data = True,
Expand Down