Skip to content

Commit

Permalink
Merge HDF5 file utilities (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Oct 22, 2016
1 parent 158189a commit 23c2029
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 200 deletions.
60 changes: 60 additions & 0 deletions tool/create_h5_data_from_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import h5py
import numpy as np

from luchador.env import get_env
from luchador.util import load_config


def parse_command_line_args():
    """Parse command-line arguments for state-data generation.

    Returns:
        argparse.Namespace with attributes ``env``, ``output``, ``key``,
        ``channel`` and ``batch``.
    """
    from argparse import ArgumentParser as AP
    ap = AP(
        # BUG FIX: keyword was capitalized (``Description=``), which makes
        # ArgumentParser raise TypeError at construction time.
        description='Create ALE Environment state data'
    )
    ap.add_argument('env', help='YAML file contains environment config')
    ap.add_argument('output', help='Output HDF5 file name')
    ap.add_argument('key', help='Name of dataset in the output file')
    ap.add_argument('--channel', type=int, default=4,
                    help='Number of observations per sample. Default: 4')
    ap.add_argument('--batch', type=int, default=32,
                    help='Number of samples to generate. Default: 32')
    return ap.parse_args()


def create_env(cfg_file):
    """Build an environment instance from a YAML configuration file.

    The config must contain a ``name`` (registered environment class)
    and an ``args`` mapping of constructor keyword arguments.
    """
    config = load_config(cfg_file)
    env_class = get_env(config['name'])
    environment = env_class(**config['args'])
    print('\n{}'.format(environment))
    return environment


def create_data(env, channel, batch):
    """Collect observations from ``env`` by repeatedly taking action 0.

    Each sample is ``channel`` consecutive observations; the environment
    is reset whenever a terminal outcome is reached.

    Returns:
        uint8 ndarray of shape ``(batch, channel) + observation.shape``.
    """
    env.reset()
    batches = []
    for _ in range(batch):
        frames = []
        for _ in range(channel):
            result = env.step(0)
            frames.append(result.observation)
            if result.terminal:
                env.reset()
        batches.append(frames)
    return np.asarray(batches, dtype=np.uint8)


def save(data, output_file, key='data'):
    """Write ``data`` to ``output_file`` as dataset ``key``.

    The file is opened in append mode; an existing dataset with the same
    name is replaced.
    """
    # Context manager guarantees the handle is closed (and data flushed)
    # even if create_dataset raises; the original leaked the handle on error.
    with h5py.File(output_file, 'a') as f:
        if key in f:
            del f[key]
        f.create_dataset(key, data=data)


def main():
    """Entry point: build the environment, sample observations, store them."""
    args = parse_command_line_args()
    environment = create_env(args.env)
    observations = create_data(environment, args.channel, args.batch)
    save(observations, args.output, args.key)


# Allow running this module directly as a command-line script.
if __name__ == '__main__':
    main()
21 changes: 0 additions & 21 deletions tool/delete_dataset_from_h5.py

This file was deleted.

193 changes: 193 additions & 0 deletions tool/edit_hdf5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#!/usr/bin/env python

"""Command line tool to edit HDF5 file"""

from __future__ import print_function

import sys
from collections import OrderedDict
from argparse import ArgumentParser as AP

import h5py
import numpy as np


def load_hdf5(filepath, mode='r'):
    """Open an HDF5 file and return the ``h5py.File`` object.

    Note: despite the original comment, this does NOT unnest the group
    structure; pass the result to ``unnest_hdf5`` for that.
    """
    return h5py.File(filepath, mode)


def unnest_hdf5(obj, prefix='', ret=None):
    """Flatten nested HDF5 groups into an OrderedDict.

    Keys are '/'-joined paths; values are the leaf (dataset) objects.
    ``ret`` is the accumulator used during recursion.
    """
    result = OrderedDict() if ret is None else ret
    for name, node in obj.items():
        full_path = '{}/{}'.format(prefix, name)
        if isinstance(node, h5py.Group):
            # Recurse into sub-groups, accumulating into the same dict.
            unnest_hdf5(node, full_path, result)
        else:
            result[full_path] = node
    return result


def get_dataset_summary(f):
    """Compute per-dataset statistics for a mapping of path -> dataset.

    Returns an OrderedDict mapping each path to a dict with keys
    'dtype', 'shape', 'mean', 'sum', 'max' and 'min'.
    """
    summary = OrderedDict()
    for path, dataset in f.items():
        summary[path] = {
            'dtype': dataset.dtype,
            'shape': dataset.shape,
            'mean': np.mean(dataset),
            'sum': np.sum(dataset),
            'max': np.max(dataset),
            'min': np.min(dataset),
        }
    return summary


def max_str(l):
    """Return the length of the longest ``str()`` representation in *l*."""
    return max(map(lambda e: len(str(e)), l))


def print_summary(summary):
    """Print a column-aligned table of dataset statistics.

    Args:
        summary: mapping of dataset path to a dict with keys
            'dtype', 'shape', 'sum', 'max', 'min' and 'mean'
            (as produced by ``get_dataset_summary``).
    """
    # Column widths sized to the longest entry plus one space of padding.
    dtype_len = max_str([s['dtype'] for s in summary.values()]) + 1
    shape_len = max_str([s['shape'] for s in summary.values()]) + 1
    path_len = max_str(summary.keys()) + 1
    print(
        '{path:{path_len}}{dtype:{dtype_len}}{shape:{shape_len}} '
        '{sum:>10} {max:>10} {min:>10} {mean:>10}'
        .format(
            dtype='dtype', dtype_len=dtype_len,
            shape='shape', shape_len=shape_len,
            path='path', path_len=path_len,
            sum='sum', max='max', min='min', mean='mean'
        )
    )
    for path, s in summary.items():
        # BUG FIX: ``dtype`` (np.dtype) and ``shape`` (tuple) do not
        # implement __format__ for width specs; on Python 3 formatting
        # them with '{:{width}}' raises TypeError. Convert to str first.
        print(
            '{path:{path_len}}{dtype:{dtype_len}}{shape:{shape_len}} '
            '{sum:10.3E} {max:10.3E} {min:10.3E} {mean:10.3E}'
            .format(
                dtype=str(s['dtype']), dtype_len=dtype_len,
                shape=str(s['shape']), shape_len=shape_len,
                path=path, path_len=path_len,
                sum=s['sum'], max=s['max'], min=s['min'], mean=s['mean'],
            )
        )


class HDF5Editor(object):
    """Dispatch command-line subcommands for inspecting and editing HDF5 files.

    The first positional argument selects the subcommand; the remaining
    argv tokens are handed to the corresponding method.
    """
    def __init__(self):
        ap = AP(
            description='Inspect HDF5 Data'
        )
        ap.add_argument('command', choices=['inspect', 'delete', 'rename', 'view'])

        # Parse only argv[1] (the subcommand), then delegate the rest.
        args = ap.parse_args(sys.argv[1:2])
        getattr(self, args.command)(sys.argv[2:])

    def inspect(self, argv):
        """Print summary statistics for every dataset in the file."""
        ap = AP(
            description='List up datasets in the given file.',
            usage='{} {} [-h] input_file'.format(sys.argv[0], 'inspect')
        )
        ap.add_argument('input_file', help='Input HDF5 file')
        args = ap.parse_args(argv)

        f = unnest_hdf5(load_hdf5(args.input_file))
        print_summary(get_dataset_summary(f))

    def delete(self, argv):
        """Delete one or more datasets from the file."""
        ap = AP(
            description='Delete a dataset from H5 file',
            usage=('{} {} [-h] input_file keys [keys ...]'
                   .format(sys.argv[0], 'delete'))
        )
        ap.add_argument('input_file', help='Input HDF5 file.')
        ap.add_argument('keys', nargs='+', help='Names of dataset to delete')
        ap.add_argument(
            '--dry-run', '--dryrun', action='store_true',
            help='Do not apply change to the file.'
        )
        args = ap.parse_args(argv)

        f = load_hdf5(args.input_file, 'r+')
        # Validate every key up front so a typo cannot leave the file
        # partially edited. (Also fixes the 'Databset' message typo.)
        for key in args.keys:
            if key not in f:
                raise KeyError('Dataset not found: {}'.format(key))

        for key in args.keys:
            print('{}Deleting key: {}'
                  .format('(dryrun) ' if args.dry_run else '', key))
            if not args.dry_run:
                del f[key]
        # Close explicitly so changes are flushed to disk.
        f.close()

    def rename(self, argv):
        """Rename a dataset, optionally overwriting an existing one."""
        ap = AP(
            description='Rename a dataset in H5 file',
            usage=('{} {} [-h] input_file old_key new_key'
                   .format(sys.argv[0], 'rename'))
        )
        ap.add_argument('input_file', help='Input H5 file.')
        ap.add_argument('old_key', help='Dataset to rename')
        ap.add_argument('new_key', help='New Dataset name')
        ap.add_argument(
            # BUG FIX: was missing action='store_true', so --force
            # demanded a value; it was also never consulted below.
            '--force', '-f', action='store_true',
            help='Overwrite in case the dataset with new_key exists.'
        )
        ap.add_argument(
            '--dry-run', '--dryrun', action='store_true',
            help='Do not apply change to the file.'
        )
        args = ap.parse_args(argv)

        f = load_hdf5(args.input_file, 'r+')
        if args.old_key not in f:
            raise KeyError('Dataset not found: {}'.format(args.old_key))

        if args.new_key in f:
            if not args.force:
                raise KeyError('Dataset exists: {}'.format(args.new_key))
            if not args.dry_run:
                # --force: drop the existing dataset so the assignment
                # below does not collide.
                del f[args.new_key]

        print('{}Renaming {} to {}'.format(
            '(dryrun) ' if args.dry_run else '', args.old_key, args.new_key))
        if not args.dry_run:
            f[args.new_key] = f[args.old_key]
            del f[args.old_key]
        # Close explicitly so changes are flushed to disk.
        f.close()

    def view(self, argv):
        """Plot the filters of one batch of a 4-D dataset."""
        import matplotlib.pyplot as plt

        ap = AP(
            description='Visualize output from convolution',
        )
        ap.add_argument('input_file', help='Input H5 file.')
        ap.add_argument('key', help='Datasets to visualize')
        ap.add_argument(
            '--batch', type=int, default=0,
            help='Batch number to visualize'
        )
        ap.add_argument(
            '--format', default='NCHW',
            help='Data format. Either NCHW or NHWC. Default: NCHW'
        )
        args = ap.parse_args(argv)

        f = load_hdf5(args.input_file, 'r')
        data = np.asarray(f[args.key])
        if args.format == 'NHWC':
            # BUG FIX: transpose returns a new array; the original
            # discarded the result, so NHWC data was never reordered.
            data = data.transpose((0, 3, 1, 2))

        n_filters = data.shape[1]
        # add_subplot requires integer grid dimensions.
        n_rows = int(np.floor(np.sqrt(n_filters)))
        n_cols = int(np.ceil(n_filters / float(n_rows)))

        vmin, vmax = data.min(), data.max()
        fig = plt.figure()
        fig.suptitle('{}\nBatch: {}'.format(args.input_file, args.batch))
        for index, filter_ in enumerate(data[args.batch], start=1):
            axis = fig.add_subplot(n_rows, n_cols, index)
            axis.imshow(filter_, cmap='Greys', vmin=vmin, vmax=vmax)
            axis.set_title('Filter: {}'.format(index))
        print('Plot ready')
        plt.show()


# Constructing the editor parses sys.argv and dispatches the subcommand.
if __name__ == '__main__':
    HDF5Editor()
83 changes: 0 additions & 83 deletions tool/inspect_h5.py

This file was deleted.

22 changes: 0 additions & 22 deletions tool/rename_dataset_in_h5_data.py

This file was deleted.

Loading

0 comments on commit 23c2029

Please sign in to comment.