diff --git a/.cspell.json b/.cspell.json index 07b8655086..5c81bd0d5f 100644 --- a/.cspell.json +++ b/.cspell.json @@ -9,6 +9,7 @@ "bincount", "cacheable", "chisquare", + "commandline", "conda", "configspace", "dataframe", @@ -62,6 +63,7 @@ "skopt", "smac", "sqlalchemy", + "subcmd", "subschema", "subschemas", "tolist", diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index 704966c7b0..8214ef9796 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -46,7 +46,7 @@ def __init__(self, tunables: TunableGroups, service: Optional[Service], config: self._start_with_defaults: bool = bool( strtobool(str(self._config.pop('start_with_defaults', True)))) self._max_iter = int(self._config.pop('max_iterations', 100)) - self._opt_target = self._config.pop('optimization_target', 'score') + self._opt_target = str(self._config.pop('optimization_target', 'score')) self._opt_sign = {"min": 1, "max": -1}[self._config.pop('optimization_direction', 'min')] def __repr__(self) -> str: diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py index 85a640d039..b2b48ce6f0 100644 --- a/mlos_bench/mlos_bench/services/local/local_exec.py +++ b/mlos_bench/mlos_bench/services/local/local_exec.py @@ -13,7 +13,7 @@ import subprocess import sys -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING from mlos_bench.services.base_service import Service from mlos_bench.services.local.temp_dir_context import TempDirContextService @@ -25,6 +25,41 @@ _LOG = logging.getLogger(__name__) +def split_cmdline(cmdline: str) -> Iterable[List[str]]: + """ + A single command line may contain multiple commands separated by + special characters (e.g., &&, ||, etc.) so further split the + commandline into an array of subcommand arrays. + + Parameters + ---------- + cmdline: str + The commandline to split. + + Yields + ------ + Iterable[List[str]] + A list of subcommands or separators, each one a list of tokens. + Can be rejoined as a flattened array. + """ + cmdline_tokens = shlex.shlex(cmdline, posix=True, punctuation_chars=True) + cmdline_tokens.whitespace_split = True + subcmd = [] + for token in cmdline_tokens: + if token[0] not in cmdline_tokens.punctuation_chars: + subcmd.append(token) + else: + # Separator encountered. Yield any non-empty previous subcmd we accumulated. + if subcmd: + yield subcmd + # Also return the separators. + yield [token] + subcmd = [] + # Return the trailing subcommand. + if subcmd: + yield subcmd + + class LocalExecService(TempDirContextService, SupportsLocalExec): """ Collection of methods to run scripts and commands in an external process @@ -98,6 +133,34 @@ def local_exec(self, script_lines: Iterable[str], return (return_code, stdout, stderr) + def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]: + """ + Resolves local script path (first token) in the (sub)command line + tokens to its full path. + + Parameters + ---------- + subcmd_tokens : List[str] + The previously split tokens of the subcmd. + + Returns + ------- + List[str] + A modified sub command line with the script paths resolved. + """ + script_path = self.config_loader_service.resolve_path(subcmd_tokens[0]) + # Special case check for lone `.` which means both `source` and + # "current directory" (which isn't executable) in posix shells. + if os.path.exists(script_path) and os.path.isfile(script_path): + # If the script exists, use it. + subcmd_tokens[0] = os.path.abspath(script_path) + # Also check if it is a python script and prepend the currently + # executing python executable path to avoid requiring + # executable mode bits or a shebang. + if script_path.strip().lower().endswith(".py"): + subcmd_tokens.insert(0, sys.executable) + return subcmd_tokens + def _local_exec_script(self, script_line: str, env_params: Optional[Mapping[str, "TunableValue"]], cwd: str) -> Tuple[int, str, str]: @@ -107,9 +170,7 @@ def _local_exec_script(self, script_line: str, Parameters ---------- script_line : str - Line of the script to tun in the local process. - args : Iterable[str] - Command line arguments for the script. + Line of the script to run in the local process. env_params : Mapping[str, Union[int, float, str]] Environment variables. cwd : str @@ -120,18 +181,12 @@ def _local_exec_script(self, script_line: str, (return_code, stdout, stderr) : (int, str, str) A 3-tuple of return code, stdout, and stderr of the script process. """ - cmd = shlex.split(script_line) - script_path = self.config_loader_service.resolve_path(cmd[0]) - # special case handling for leading lone `.` character, which also means `source` - # (as in, include, but not execute) in shell syntax - if os.path.exists(script_path) and not os.path.isdir(script_path): - script_path = os.path.abspath(script_path) - else: - script_path = cmd[0] # rollback to the original value - - cmd = [script_path] + cmd[1:] - if script_path.strip().lower().endswith(".py"): - cmd = [sys.executable] + cmd + # Split the command line into set of subcmd tokens. + # For each subcmd, perform path resolution fixups for any scripts being executed. + subcmds = split_cmdline(script_line) + subcmds = [self._resolve_cmdline_script_path(subcmd) for subcmd in subcmds] + # Finally recombine all of the fixed up subcmd tokens into the original. + cmd = [token for subcmd in subcmds for token in subcmd] env: Dict[str, str] = {} if env_params: @@ -144,12 +199,12 @@ def _local_exec_script(self, script_line: str, env_copy.update(env) env = env_copy - _LOG.info("Run: %s", cmd) - try: if sys.platform != 'win32': cmd = [" ".join(cmd)] + _LOG.info("Run: %s", cmd) + proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True, text=True, check=False, capture_output=True) diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py index cdf49c106e..b637cf8f81 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py @@ -10,7 +10,7 @@ import pytest import pandas -from mlos_bench.services.local.local_exec import LocalExecService +from mlos_bench.services.local.local_exec import LocalExecService, split_cmdline from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.util import path_join @@ -19,6 +19,32 @@ # `local_exec_service` fixture as both a function and a parameter. +def test_split_cmdline() -> None: + """ + Test splitting a commandline into subcommands. + """ + cmdline = ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" + assert list(split_cmdline(cmdline)) == [ + ['.', 'env.sh'], + ['&&'], + ['('], + ['echo', 'hello'], + ['&&'], + ['echo', 'world'], + ['|'], + ['tee'], + ['>'], + ['/tmp/test'], + ['||'], + ['echo', 'foo'], + ['&&'], + ['echo', '$var'], + [';'], + ['true'], + [')'], + ] + + @pytest.fixture def local_exec_service() -> LocalExecService: """ @@ -27,6 +53,22 @@ def local_exec_service() -> LocalExecService: return LocalExecService(parent=ConfigPersistenceService()) +def test_resolve_script(local_exec_service: LocalExecService) -> None: + """ + Test local script resolution logic with complex subcommand names. + """ + script = "os/linux/runtime/scripts/local/generate_kernel_config_script.py" + script_abspath = local_exec_service.config_loader_service.resolve_path(script) + orig_cmdline = f". env.sh && {script}" + expected_cmdline = f". env.sh && {script_abspath}" + subcmds_tokens = split_cmdline(orig_cmdline) + # pylint: disable=protected-access + subcmds_tokens = [local_exec_service._resolve_cmdline_script_path(subcmd_tokens) for subcmd_tokens in subcmds_tokens] + cmdline_tokens = [token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens] + expanded_cmdline = " ".join(cmdline_tokens) + assert expanded_cmdline == expected_cmdline + + def test_run_script(local_exec_service: LocalExecService) -> None: """ Run a script locally and check the results. @@ -86,6 +128,12 @@ def test_run_script_read_csv(local_exec_service: LocalExecService) -> None: assert stderr.strip() == "" data = pandas.read_csv(path_join(temp_dir, "output.csv")) + if sys.platform == 'win32': + # Workaround for Python's subprocess module on Windows adding a + # space inbetween the col1,col2 arg and the redirect symbol which + # cmd poorly interprets as being part of the original string arg. + # Without this, we get "col2 " as the second column name. + data.rename(str.rstrip, axis='columns', inplace=True) assert all(data.col1 == [111, 333]) assert all(data.col2 == [222, 444])