Skip to content

Commit

Permalink
Merge pull request #132 from HDI-Project/issue_131_add_demo_datasets
Browse files Browse the repository at this point in the history
Issue 131 add demo datasets
  • Loading branch information
csala authored Mar 15, 2019
2 parents 2f5226c + b82001e commit 2f0e975
Show file tree
Hide file tree
Showing 8 changed files with 637 additions and 105 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,8 @@ ENV/
# vim
.*.swp

# added by rjdiez
# IntelliJ Idea
.idea/

# cached datasets
mlprimitives/data/
89 changes: 0 additions & 89 deletions mlprimitives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,95 +6,6 @@
__email__ = 'dailabmit@gmail.com'
__version__ = '0.1.7-dev'

import argparse
import logging
import os
import warnings

from mlblocks import add_primitives_path, get_primitives_paths

from mlprimitives.evaluation import score_pipeline

MLPRIMITIVES_JSONS_PATH = os.path.join(os.path.dirname(__file__), 'jsons')
LOGGER = logging.getLogger(__name__)


def _logging_setup(verbosity=1):
log_level = (3 - verbosity) * 10
fmt = '%(asctime)s - %(levelname)s - %(message)s'
formatter = logging.Formatter(fmt)
LOGGER.setLevel(log_level)
LOGGER.propagate = False

console_handler = logging.StreamHandler()
console_handler.setLevel(log_level)
console_handler.setFormatter(formatter)
LOGGER.addHandler(console_handler)


def _test(args):
for pipeline in args.pipeline:
print('Scoring pipeline: {}'.format(pipeline))
score, stdev = score_pipeline(pipeline, args.splits)
print('Obtained Score: {:.4f} +/- {:.4f}'.format(score, stdev))


def _get_primitives(pattern):
primitives = list()
for base_path in get_primitives_paths():
if os.path.exists(base_path):
for filename in os.listdir(base_path):
if pattern in filename and filename.endswith('.json'):
primitives.append(filename[:-5])

return list(sorted(primitives))


def _list(args):
print('\n'.join(_get_primitives(args.pattern)))


def _parse_args():
parser = argparse.ArgumentParser(description='MLPrimitives Command Line Interface')

parser.add_argument(
'-p', '--primitives-path', action='append', help=(
'Path where primitives should be looked for. Use multiple '
'times in order to add multiple directories'
)
)
parser.add_argument('-v', '--verbose', action='count', default=0)

subparsers = parser.add_subparsers(title='action', help='Action to perform')

subparser = subparsers.add_parser('test', help='Test a single pipeline.')
subparser.set_defaults(action=_test)
subparser.add_argument('-s', '--splits', default=1, type=int,
help='Number of splits to use for Cross Validation')
subparser.add_argument('pipeline', nargs='+')

subparser = subparsers.add_parser('list', help='List available primitives')
subparser.set_defaults(action=_list)
subparser.add_argument('pattern', nargs='?', default='')

return parser.parse_args()


def _add_primitives_paths(paths):
if paths:
for path in paths:
add_primitives_path(path)


def _process_common_args(args):
_add_primitives_paths(args.primitives_path)
_logging_setup(args.verbose)


def _main():
warnings.filterwarnings('ignore', category=DeprecationWarning)

args = _parse_args()
_process_common_args(args)

args.action(args)
112 changes: 112 additions & 0 deletions mlprimitives/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-

"""MLPrimitives Command Line Interface module."""

import argparse
import logging
import os
import sys
import warnings

from mlblocks import add_primitives_path, get_primitives_paths

from mlprimitives.evaluation import score_pipeline

LOGGER = logging.getLogger(__name__)


def _logging_setup(verbosity=1):
logger = logging.getLogger()
log_level = (3 - verbosity) * 10
fmt = '%(asctime)s - %(levelname)s - %(message)s'
formatter = logging.Formatter(fmt)
logger.setLevel(log_level)
logger.propagate = False

console_handler = logging.StreamHandler()
console_handler.setLevel(log_level)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


def _test(args):
for pipeline in args.pipeline:
print('Scoring pipeline: {}'.format(pipeline))
score, stdev = score_pipeline(pipeline, args.splits)
print('Obtained Score: {:.4f} +/- {:.4f}'.format(score, stdev))


def _get_primitives(pattern):
primitives = list()
for base_path in get_primitives_paths():
if os.path.exists(base_path):
for filename in os.listdir(base_path):
if pattern in filename and filename.endswith('.json'):
primitives.append(filename[:-5])

return list(sorted(primitives))


def _list(args):
print('\n'.join(_get_primitives(args.pattern)))


class ArgumentParser(argparse.ArgumentParser):

def error(self, message):
sys.stderr.write('\nERROR: {}\n\n'.format(message))
self.print_help()
sys.exit(2)


def _get_parser():
parser = ArgumentParser(
description='MLPrimitives Command Line Interface')

parser.add_argument(
'-p', '--primitives-path', action='append', help=(
'Path where primitives should be looked for. Use multiple '
'times in order to add multiple directories'
)
)
parser.add_argument('-v', '--verbose', action='count', default=0)

subparsers = parser.add_subparsers(title='action', help='Action to perform')
parser.set_defaults(action=None)

subparser = subparsers.add_parser('test', help='Test a single pipeline.')
subparser.set_defaults(action=_test)
subparser.add_argument('-s', '--splits', default=1, type=int,
help='Number of splits to use for Cross Validation')
subparser.add_argument('pipeline', nargs='+')

subparser = subparsers.add_parser('list', help='List available primitives')
subparser.set_defaults(action=_list)
subparser.add_argument('pattern', nargs='?', default='')

return parser


def _add_primitives_paths(paths):
if paths:
for path in paths:
add_primitives_path(path)


def _process_common_args(args):
_add_primitives_paths(args.primitives_path)
_logging_setup(args.verbose)


def main():
warnings.filterwarnings('ignore', category=DeprecationWarning)

parser = _get_parser()
args = parser.parse_args()
if not args.action:
parser.print_help()
sys.exit(0)

_process_common_args(args)

args.action(args)
Loading

0 comments on commit 2f0e975

Please sign in to comment.