Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mlos_bench: shell command line parsing improvements #461

Merged
merged 12 commits into from
Jul 28, 2023
2 changes: 2 additions & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"bincount",
"cacheable",
"chisquare",
"commandline",
"conda",
"configspace",
"dataframe",
Expand Down Expand Up @@ -62,6 +63,7 @@
"skopt",
"smac",
"sqlalchemy",
"subcmd",
"subschema",
"subschemas",
"tolist",
Expand Down
91 changes: 73 additions & 18 deletions mlos_bench/mlos_bench/services/local/local_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import subprocess
import sys

from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING

from mlos_bench.services.base_service import Service
from mlos_bench.services.local.temp_dir_context import TempDirContextService
Expand All @@ -25,6 +25,41 @@
_LOG = logging.getLogger(__name__)


def split_cmdline(cmdline: str) -> Iterable[List[str]]:
"""
A single command line may contain multiple commands separated by
special characters (e.g., &&, ||, etc.) so further split the
commandline into an array of subcommand arrays.

Parameters
----------
cmdline: str
The commandline to split.

Yields
------
Iterable[List[str]]
A list of subcommands or separators, each one a list of tokens.
Can be rejoined as a flattened array.
"""
cmdline_tokens = shlex.shlex(cmdline, posix=True, punctuation_chars=True)
cmdline_tokens.whitespace_split = True
subcmd = []
for token in cmdline_tokens:
if token[0] not in cmdline_tokens.punctuation_chars:
subcmd.append(token)
else:
# Separator encountered. Yield any non-empty previous subcmd we accumulated.
if subcmd:
yield subcmd
# Also return the separators.
yield [token]
subcmd = []
# Return the trailing subcommand.
if subcmd:
yield subcmd


class LocalExecService(TempDirContextService, SupportsLocalExec):
"""
Collection of methods to run scripts and commands in an external process
Expand Down Expand Up @@ -98,6 +133,34 @@ def local_exec(self, script_lines: Iterable[str],

return (return_code, stdout, stderr)

def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]:
"""
Resolves local script path (first token) in the (sub)command line
tokens to its full path.

Parameters
----------
subcmd_tokens : List[str]
The previously split tokens of the subcmd.

Returns
-------
List[str]
A modified sub command line with the script paths resolved.
"""
script_path = self.config_loader_service.resolve_path(subcmd_tokens[0])
# Special case check for lone `.` which means both `source` and
# "current directory" (which isn't executable) in posix shells.
if os.path.exists(script_path) and os.path.isfile(script_path):
# If the script exists, use it.
subcmd_tokens[0] = os.path.abspath(script_path)
# Also check if it is a python script and prepend the currently
# executing python executable path to avoid requiring
# executable mode bits or a shebang.
if script_path.strip().lower().endswith(".py"):
subcmd_tokens.insert(0, sys.executable)
return subcmd_tokens

def _local_exec_script(self, script_line: str,
env_params: Optional[Mapping[str, "TunableValue"]],
cwd: str) -> Tuple[int, str, str]:
Expand All @@ -107,9 +170,7 @@ def _local_exec_script(self, script_line: str,
Parameters
----------
script_line : str
Line of the script to tun in the local process.
args : Iterable[str]
Command line arguments for the script.
Line of the script to run in the local process.
env_params : Mapping[str, Union[int, float, str]]
Environment variables.
cwd : str
Expand All @@ -120,18 +181,12 @@ def _local_exec_script(self, script_line: str,
(return_code, stdout, stderr) : (int, str, str)
A 3-tuple of return code, stdout, and stderr of the script process.
"""
cmd = shlex.split(script_line)
script_path = self.config_loader_service.resolve_path(cmd[0])
# special case handling for leading lone `.` character, which also means `source`
# (as in, include, but not execute) in shell syntax
if os.path.exists(script_path) and not os.path.isdir(script_path):
script_path = os.path.abspath(script_path)
else:
script_path = cmd[0] # rollback to the original value

cmd = [script_path] + cmd[1:]
if script_path.strip().lower().endswith(".py"):
cmd = [sys.executable] + cmd
# Split the command line into set of subcmd tokens.
# For each subcmd, perform path resolution fixups for any scripts being executed.
subcmds = split_cmdline(script_line)
subcmds = [self._resolve_cmdline_script_path(subcmd) for subcmd in subcmds]
# Finally recombine all of the fixed up subcmd tokens into the original.
cmd = [token for subcmd in subcmds for token in subcmd]
bpkroth marked this conversation as resolved.
Show resolved Hide resolved

env: Dict[str, str] = {}
if env_params:
Expand All @@ -144,12 +199,12 @@ def _local_exec_script(self, script_line: str,
env_copy.update(env)
env = env_copy

_LOG.info("Run: %s", cmd)

try:
if sys.platform != 'win32':
cmd = [" ".join(cmd)]

_LOG.info("Run: %s", cmd)

proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True,
text=True, check=False, capture_output=True)

Expand Down
55 changes: 54 additions & 1 deletion mlos_bench/mlos_bench/tests/services/local/local_exec_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
"""
Unit tests for the service to run the scripts locally.
"""
import os
import sys

import pytest
import pandas

from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.services.local.local_exec import LocalExecService, split_cmdline
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.util import path_join

Expand All @@ -19,6 +20,32 @@
# `local_exec_service` fixture as both a function and a parameter.


def test_split_cmdline() -> None:
"""
Test splitting a commandline into subcommands.
"""
cmdline = ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)"
assert list(split_cmdline(cmdline)) == [
['.', 'env.sh'],
['&&'],
['('],
['echo', 'hello'],
['&&'],
['echo', 'world'],
['|'],
['tee'],
['>'],
['/tmp/test'],
['||'],
['echo', 'foo'],
['&&'],
['echo', '$var'],
[';'],
['true'],
[')'],
]


@pytest.fixture
def local_exec_service() -> LocalExecService:
"""
Expand All @@ -27,6 +54,22 @@ def local_exec_service() -> LocalExecService:
return LocalExecService(parent=ConfigPersistenceService())


def test_resolve_script(local_exec_service: LocalExecService) -> None:
"""
Test local script resolution logic with complex subcommand names.
"""
script = "os/linux/runtime/scripts/local/generate_kernel_config_script.py"
script_abspath = local_exec_service.config_loader_service.resolve_path(script)
orig_cmdline = f". env.sh && {script}"
expected_cmdline = f". env.sh && {script_abspath}"
subcmds_tokens = split_cmdline(orig_cmdline)
# pylint: disable=protected-access
subcmds_tokens = [local_exec_service._resolve_cmdline_script_path(subcmd_tokens) for subcmd_tokens in subcmds_tokens]
cmdline_tokens = [token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens]
expanded_cmdline = " ".join(cmdline_tokens)
assert expanded_cmdline == expected_cmdline


def test_run_script(local_exec_service: LocalExecService) -> None:
"""
Run a script locally and check the results.
Expand Down Expand Up @@ -85,6 +128,16 @@ def test_run_script_read_csv(local_exec_service: LocalExecService) -> None:
assert stdout.strip() == ""
assert stderr.strip() == ""

if sys.platform == 'win32':
# On Windows we need to remove the trailing ' ' from the CSV file
# that was written due to the way the Python recomposes shell
# expansions by reintroducing ' 's between lists of arguments.
with open(path_join(temp_dir, "output.csv"), "rt", encoding="utf-8") as fh_output:
lines = fh_output.readlines()
with open(path_join(temp_dir, "output.csv"), "wt", encoding="utf-8") as fh_output:
for line in lines:
fh_output.write(line.rstrip() + os.linesep)
bpkroth marked this conversation as resolved.
Show resolved Hide resolved

data = pandas.read_csv(path_join(temp_dir, "output.csv"))
assert all(data.col1 == [111, 333])
assert all(data.col2 == [222, 444])
Expand Down
Loading