generated from entelecheia/hyperfast-python-template
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(hydra): add main.py and utils.py modules
- Loading branch information
1 parent
ecdd8e3
commit d68c7a8
Showing
2 changed files
with
219 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
import functools | ||
from textwrap import dedent | ||
from typing import Any, Callable, Optional | ||
|
||
from hydra import version | ||
from hydra._internal.deprecation_warning import deprecation_warning | ||
from hydra._internal.utils import _run_hydra, get_args_parser | ||
from hydra.core.utils import _flush_loggers | ||
from hydra.main import _get_rerun_conf | ||
from hydra.types import TaskFunction | ||
from omegaconf import DictConfig | ||
|
||
_UNSPECIFIED_: Any = object() | ||
|
||
|
||
def main( | ||
config_path: Optional[str] = _UNSPECIFIED_, | ||
config_name: Optional[str] = None, | ||
version_base: Optional[str] = _UNSPECIFIED_, | ||
) -> Callable[[TaskFunction], Any]: | ||
""" | ||
:param config_path: The config path, a directory where Hydra will search for | ||
config files. This path is added to Hydra's searchpath. | ||
Relative paths are interpreted relative to the declaring python | ||
file. Alternatively, you can use the prefix `pkg://` to specify | ||
a python package to add to the searchpath. | ||
If config_path is None no directory is added to the Config search path. | ||
:param config_name: The name of the config (usually the file name without the .yaml extension) | ||
""" | ||
|
||
version.setbase(version_base) | ||
|
||
if config_path is _UNSPECIFIED_: | ||
if version.base_at_least("1.2"): | ||
config_path = None | ||
elif version_base is _UNSPECIFIED_: | ||
url = "https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path" | ||
deprecation_warning( | ||
message=dedent( | ||
f""" | ||
config_path is not specified in @hydra.main(). | ||
See {url} for more information.""" | ||
), | ||
stacklevel=2, | ||
) | ||
config_path = "." | ||
else: | ||
config_path = "." | ||
|
||
def main_decorator(task_function: TaskFunction) -> Callable[[], None]: | ||
@functools.wraps(task_function) | ||
def decorated_main(cfg_passthrough: Optional[DictConfig] = None) -> Any: | ||
if cfg_passthrough is not None: | ||
return task_function(cfg_passthrough) | ||
args_parser = get_args_parser() | ||
args = args_parser.parse_args() | ||
if args.experimental_rerun is not None: | ||
cfg = _get_rerun_conf(args.experimental_rerun, args.overrides) | ||
task_function(cfg) | ||
_flush_loggers() | ||
else: | ||
# no return value from run_hydra() as it may sometime actually run the task_function | ||
# multiple times (--multirun) | ||
_run_hydra( | ||
args=args, | ||
args_parser=args_parser, | ||
task_function=task_function, | ||
config_path=config_path, | ||
config_name=config_name, | ||
) | ||
|
||
return decorated_main | ||
|
||
return main_decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
import argparse | ||
import logging.config | ||
import os | ||
import sys | ||
from typing import Optional | ||
|
||
from hydra._internal.utils import ( | ||
_run_app, | ||
create_automatic_config_search_path, | ||
detect_calling_file_or_module_from_stack_frame, | ||
detect_calling_file_or_module_from_task_function, | ||
detect_task_name, | ||
run_and_report, | ||
) | ||
from hydra.core.config_search_path import SearchPathQuery | ||
from hydra.core.utils import validate_config_path | ||
from hydra.errors import SearchPathException | ||
from hydra.types import TaskFunction | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def _run_hydra( | ||
args: argparse.Namespace, | ||
args_parser: argparse.ArgumentParser, | ||
task_function: TaskFunction, | ||
config_path: Optional[str], | ||
config_name: Optional[str], | ||
caller_stack_depth: int = 2, | ||
) -> None: | ||
from hydra._internal.hydra import Hydra | ||
from hydra.core.global_hydra import GlobalHydra | ||
|
||
if args.config_name is not None: | ||
config_name = args.config_name | ||
|
||
if args.config_path is not None: | ||
config_path = args.config_path | ||
|
||
( | ||
calling_file, | ||
calling_module, | ||
) = detect_calling_file_or_module_from_task_function(task_function) | ||
if calling_file is None and calling_module is None: | ||
( | ||
calling_file, | ||
calling_module, | ||
) = detect_calling_file_or_module_from_stack_frame(caller_stack_depth + 1) | ||
task_name = detect_task_name(calling_file, calling_module) | ||
|
||
validate_config_path(config_path) | ||
|
||
search_path = create_automatic_config_search_path( | ||
calling_file, calling_module, config_path | ||
) | ||
|
||
def add_conf_dir() -> None: | ||
if args.config_dir is not None: | ||
abs_config_dir = os.path.abspath(args.config_dir) | ||
if not os.path.isdir(abs_config_dir): | ||
raise SearchPathException( | ||
f"Additional config directory '{abs_config_dir}' not found" | ||
) | ||
search_path.prepend( | ||
provider="command-line", | ||
path=f"file://{abs_config_dir}", | ||
anchor=SearchPathQuery(provider="schema"), | ||
) | ||
|
||
run_and_report(add_conf_dir) | ||
hydra = run_and_report( | ||
lambda: Hydra.create_main_hydra2( | ||
task_name=task_name, config_search_path=search_path | ||
) | ||
) | ||
|
||
try: | ||
if args.help: | ||
hydra.app_help(config_name=config_name, args_parser=args_parser, args=args) | ||
sys.exit(0) | ||
has_show_cfg = args.cfg is not None | ||
if args.resolve and (not has_show_cfg and not args.help): | ||
raise ValueError( | ||
"The --resolve flag can only be used in conjunction with --cfg or --help" | ||
) | ||
if args.hydra_help: | ||
hydra.hydra_help( | ||
config_name=config_name, args_parser=args_parser, args=args | ||
) | ||
sys.exit(0) | ||
|
||
num_commands = ( | ||
args.run | ||
+ has_show_cfg | ||
+ args.multirun | ||
+ args.shell_completion | ||
+ (args.info is not None) | ||
) | ||
if num_commands > 1: | ||
raise ValueError( | ||
"Only one of --run, --multirun, --cfg, --info and --shell_completion can be specified" | ||
) | ||
if num_commands == 0: | ||
args.run = True | ||
|
||
overrides = args.overrides | ||
|
||
if args.run or args.multirun: | ||
run_mode = hydra.get_mode(config_name=config_name, overrides=overrides) | ||
_run_app( | ||
run=args.run, | ||
multirun=args.multirun, | ||
mode=run_mode, | ||
hydra=hydra, | ||
config_name=config_name, | ||
task_function=task_function, | ||
overrides=overrides, | ||
) | ||
elif args.cfg: | ||
run_and_report( | ||
lambda: hydra.show_cfg( | ||
config_name=config_name, | ||
overrides=args.overrides, | ||
cfg_type=args.cfg, | ||
package=args.package, | ||
resolve=args.resolve, | ||
) | ||
) | ||
elif args.shell_completion: | ||
run_and_report( | ||
lambda: hydra.shell_completion( | ||
config_name=config_name, overrides=args.overrides | ||
) | ||
) | ||
elif args.info: | ||
hydra.show_info( | ||
args.info, config_name=config_name, overrides=args.overrides | ||
) | ||
else: | ||
sys.stderr.write("Command not specified\n") | ||
sys.exit(1) | ||
finally: | ||
GlobalHydra.instance().clear() |