-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from HEP-DL/feature/kwierman_asyncio
Feature/kwierman asyncio
- Loading branch information
Showing
36 changed files
with
630 additions
and
494 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,103 +1,17 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import click | ||
import logging | ||
from dl_data_validation_toolset.framework.configuration import Configuration | ||
from dl_data_validation_toolset.framework.report_gen import ReportGenerator | ||
from dl_data_validation_toolset.framework.scanner import Scanner | ||
from dl_data_validation_toolset.framework.base_test import BaseTest | ||
from dl_data_validation_toolset.framework.report import FileReport | ||
from dl_data_validation_toolset.data_tests import initialize | ||
|
||
|
||
@click.command() | ||
@click.option('-n', default=1, type=click.INT) | ||
@click.option('--scale', default=1, type=click.INT) | ||
@click.option('--thresh', default=25, type=click.INT) | ||
@click.argument('input_file', nargs=1) | ||
def print_dl_images(n, scale, thresh, input_file): | ||
logging.basicConfig(level=logging.DEBUG) | ||
import h5py | ||
from scipy.misc import imsave | ||
from scipy.stats import threshold | ||
import numpy as np | ||
input_file = h5py.File(input_file, 'r') | ||
wires = input_file['image/wires'] | ||
rawdigits = input_file['image/rawdigits'] | ||
logging.info("""Producing {} images with | ||
scale {} and threshold {}""".format(n, scale, thresh)) | ||
for index in range(n): | ||
try: | ||
image = wires[index] | ||
logging.info("Image: min: {}, max: {}".format(np.min(image), | ||
np.max(image))) | ||
buff = np.ndarray(shape=(image.shape[1], image.shape[2], | ||
image.shape[0]), | ||
dtype=np.uint8) | ||
for i in range(3): | ||
buff[:, :, i] = image[i, :, :] | ||
buff = buff * scale | ||
buff = threshold(buff, threshmin=thresh) + threshold(buff, | ||
threshmax=-thresh) | ||
logging.info("Buffer: min: {}, max: {}".format(np.min(buff), | ||
np.max(buff))) | ||
imsave('wires_{}.png'.format(index), buff) | ||
logging.info('wires_{}.png created'.format(index)) | ||
except Exception as e: | ||
logging.warning(e) | ||
try: | ||
image = rawdigits[index] | ||
logging.info("Image: min: {}, max: {}".format(np.min(image), | ||
np.max(image))) | ||
buff = np.ndarray(shape=(image.shape[1], image.shape[2], | ||
image.shape[0]), dtype=np.uint8) | ||
for i in range(3): | ||
buff[:, :, i] = image[i, :, :] | ||
buff = buff * scale | ||
buff = threshold(buff, threshmin=thresh) + threshold(buff, | ||
threshmax=-thresh) | ||
logging.info("Buffer: min: {}, max: {}".format(np.min(buff), | ||
np.max(buff))) | ||
imsave('digits_{}.png'.format(index), buff) | ||
logging.info('digits_{}.png created'.format(index)) | ||
except Exception as e: | ||
logging.warning(e) | ||
from dl_data_validation_toolset.framework.report_gen import TopReportGenerator | ||
|
||
|
||
@click.command() | ||
@click.option('--config', default=None, type=click.Path()) | ||
def generate_report(config): | ||
def main(config): | ||
logging.basicConfig(level=logging.DEBUG) | ||
logging.info("Starting") | ||
initialize() | ||
logging.info("Tests to perform: {}".format(BaseTest.__subclasses__())) | ||
configuration = None | ||
if config is None: | ||
configuration = Configuration.default() | ||
else: | ||
configuration = Configuration(config) | ||
scanner = Scanner(configuration.scan_paths) | ||
scan_results = scanner.scan() | ||
logging.debug(scan_results) | ||
|
||
# Now that we've located the files and start a report, let's | ||
# create some tests | ||
file_reports = [] | ||
for file in scan_results[0]: | ||
report = FileReport(file) | ||
for test_case in BaseTest.__subclasses__(): | ||
logging.debug(test_case) | ||
test_case(file).validate(report) | ||
# create an image with this file | ||
file_reports.append(report) | ||
|
||
# TODO: make this more comprehensive | ||
group_reports = scan_results[1] | ||
|
||
rep_gen = ReportGenerator(configuration.results_path) | ||
rep_gen.generate(file_reports, group_reports) | ||
if configuration.tar: | ||
rep_gen.tarball() | ||
else: | ||
rep_gen.move_to_results() | ||
logging.info("Finished") | ||
configuration = Configuration() | ||
if config is not None: | ||
configuration.configure(config) | ||
configuration.scan() | ||
top = TopReportGenerator() | ||
top.generate(configuration) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,81 @@ | ||
import yaml | ||
import logging | ||
import json | ||
import os | ||
|
||
|
||
class Group: | ||
""" | ||
Defines a named group of files. | ||
example instantiation: | ||
.. code_block: python | ||
g= Group("myfile.h5",'/points/to/your/dir/','my_group') | ||
""" | ||
def __init__(self, file, base_dir, name): | ||
self.group = name | ||
self.files = [file, ] | ||
self.dir = base_dir | ||
|
||
@property | ||
def full_filenames(self): | ||
return [os.path.join(self.dir, i) for i in self.files] | ||
|
||
|
||
class Configuration(object): | ||
""" | ||
De-serializes configuration into a useable object by other | ||
objects. | ||
""" | ||
logger = logging.getLogger('config') | ||
|
||
def __init__(self, path): | ||
if path is not None: | ||
self.data = yaml.load(open(path, 'r')) | ||
self.scan_paths = [os.path.join(os.getcwd(), | ||
i) for i in self.data['scan_paths']] | ||
self.results_path = os.path.join(os.getcwd(), | ||
self.data['results_path']) | ||
def __init__(self): | ||
self.data = {} | ||
self.scan_paths = [os.path.join(os.getcwd(), 'data')] | ||
self.results_path = os.path.join(os.getcwd(), 'results') | ||
self.groups = [] | ||
self.tar = True | ||
self.name = "DL Data Validation Report" | ||
|
||
def configure(self, path): | ||
self.logger.debug("Opening file: {}".format(path)) | ||
with open(path, 'r') as input_file: | ||
|
||
self.data = json.load(input_file) | ||
self.scan_paths = self.data['scan_paths'] | ||
self.results_path = self.data['results_path'] | ||
self.results_path = os.path.abspath(self.results_path) | ||
self.tar = bool(self.data['tar']) | ||
self.logger.info("Loaded Config") | ||
self.logger.info(self.data) | ||
|
||
@staticmethod | ||
def default(): | ||
c = Configuration(None) | ||
Configuration.logger.info("Using default configuration") | ||
c.scan_paths = [os.path.join(os.getcwd(), 'data')] | ||
c.results_path = os.path.join(os.getcwd(), 'results') | ||
c.tar = True | ||
return c | ||
def create_new_group(self, file, path, group): | ||
msg = "Creating new group: {} at {}".format(group, | ||
path) | ||
self.logger.debug(msg) | ||
return Group(file, path, group) | ||
|
||
def scan(self): | ||
""" | ||
Creates groups of files to be used in the parallelized framework. | ||
""" | ||
self.groups = [] | ||
for base_dir in self.scan_paths: | ||
self.logger.info("Entering: {}".format(base_dir)) | ||
files = os.listdir(base_dir) | ||
for file in files: | ||
if file.endswith('.h5'): | ||
self.logger.debug("Adding file: {} to files".format(file)) | ||
# get the group out | ||
file_parts = file.split('_') | ||
file_group = '_'.join(file_parts[:-1]) | ||
group_exists = False | ||
for group in self.groups: | ||
if group.group == file_group and group.dir == base_dir: | ||
group_exists = True | ||
group.files.append(file) | ||
if not group_exists: | ||
self.groups.append(self.create_new_group(file, base_dir, | ||
file_group)) |
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .base import BaseReport | ||
import os | ||
|
||
|
||
class TopReport(BaseReport): | ||
def __init__(self, name): | ||
self.name = name | ||
self.groups = [] | ||
|
||
def render(self, directory): | ||
with open(os.path.join(directory, 'index.html'), 'w') as index_out: | ||
self.logger.info("Writing Index Page") | ||
index_out.write(self.index_template.render(title="DL Data Report", | ||
top_report=self)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from dl_data_validation_toolset import templates | ||
from mako.lookup import TemplateLookup | ||
import logging | ||
import os | ||
|
||
|
||
class BaseReport(object): | ||
logger = logging.getLogger("frmwk.report.base") | ||
template_directory = os.path.dirname(templates.__file__) | ||
|
||
@property | ||
def lookup(self): | ||
return TemplateLookup(directories=[self.template_directory]) | ||
|
||
@property | ||
def index_template(self): | ||
return self.lookup.get_template('index.mako') | ||
|
||
@property | ||
def file_template(self): | ||
return self.lookup.get_template('file.mako') | ||
|
||
@property | ||
def group_template(self): | ||
return self.lookup.get_template('group.mako') | ||
|
||
def render(self, directory): | ||
self.logger.error("Calling default report render!") |
Oops, something went wrong.