Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly annotate run functions #997

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ setuptools>56
# Debuggery
icecream>=2.1
# typing
mypy==0.971
mypy==1.10.0
types-PyYAML==6.0.12.4
typing-extensions>=4,<5
106 changes: 106 additions & 0 deletions invoke/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Union,
Sequence,
overload,
cast,
Optional,
Mapping,
)

if TYPE_CHECKING:
from invoke.runners import Promise, Result
from invoke.watchers import StreamWatcher

from typing_extensions import Protocol, TypedDict, Unpack, Literal

class _BaseRunParams(TypedDict, total=False):
dry: bool
echo: bool
echo_format: str
echo_stdin: Optional[bool]
encoding: Optional[str]
err_stream: IO
env: Mapping[str, str]
fallback: bool
hide: Optional[bool]
in_stream: Optional[IO]
out_stream: IO
pty: bool
replace_env: bool
shell: str
timeout: Optional[int]
warn: bool
watchers: Sequence["StreamWatcher"]

class RunParams(_BaseRunParams, total=False):
"""Parameters for Runner.run"""

asynchronous: bool
disown: bool

class RunFunction(Protocol):
"""A function that runs a command."""

@overload
def __call__(
self,
command: str,
*,
disown: Literal[True],
**kwargs: Unpack[_BaseRunParams],
) -> None:
...

@overload
def __call__(
self,
command: str,
*,
disown: bool,
**kwargs: Unpack[_BaseRunParams],
) -> Optional["Result"]:
...

@overload
def __call__(
self,
command: str,
*,
asynchronous: Literal[True],
**kwargs: Unpack[_BaseRunParams],
) -> "Promise":
...

@overload
def __call__(
self,
command: str,
*,
asynchronous: bool,
**kwargs: Unpack[_BaseRunParams],
) -> Union["Promise", "Result"]:
...

@overload
def __call__(
self,
command: str,
**kwargs: Unpack[_BaseRunParams],
) -> "Result":
...

def __call__(
self,
command: str,
**kwargs: Unpack[RunParams],
) -> Optional["Result"]:
...


def annotate_run_function(func: Callable[..., Any]) -> "RunFunction":
"""Add standard run function annotations to a function."""
return cast("RunFunction", func)
2 changes: 1 addition & 1 deletion invoke/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def add_task(
name = task.name
# XXX https://github.com/python/mypy/issues/1424
elif hasattr(task.body, "func_name"):
name = task.body.func_name # type: ignore
name = task.body.func_name
elif hasattr(task.body, "__name__"):
name = task.__name__
else:
Expand Down
25 changes: 11 additions & 14 deletions invoke/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from unittest.mock import Mock

from ._types import annotate_run_function
from .config import Config, DataProxy
from .exceptions import Failure, AuthFailure, ResponseNotAccepted
from .runners import Result
Expand Down Expand Up @@ -87,6 +88,7 @@ def config(self, value: Config) -> None:
# runtime.
self._set(_config=value)

@annotate_run_function
def run(self, command: str, **kwargs: Any) -> Optional[Result]:
"""
Execute a local shell command, honoring config options.
Expand All @@ -106,12 +108,11 @@ def run(self, command: str, **kwargs: Any) -> Optional[Result]:
# NOTE: broken out of run() to allow for runner class injection in
# Fabric/etc, which needs to juggle multiple runner class types (local and
# remote).
def _run(
self, runner: "Runner", command: str, **kwargs: Any
) -> Optional[Result]:
def _run(self, runner: "Runner", command: str, **kwargs: Any) -> Optional[Result]:
command = self._prefix_commands(command)
return runner.run(command, **kwargs)

@annotate_run_function
def sudo(self, command: str, **kwargs: Any) -> Optional[Result]:
"""
Execute a shell command via ``sudo`` with password auto-response.
Expand Down Expand Up @@ -185,9 +186,7 @@ def sudo(self, command: str, **kwargs: Any) -> Optional[Result]:
return self._sudo(runner, command, **kwargs)

# NOTE: this is for runner injection; see NOTE above _run().
def _sudo(
self, runner: "Runner", command: str, **kwargs: Any
) -> Optional[Result]:
def _sudo(self, runner: "Runner", command: str, **kwargs: Any) -> Optional[Result]:
prompt = self.config.sudo.prompt
password = kwargs.pop("password", self.config.sudo.password)
user = kwargs.pop("user", self.config.sudo.user)
Expand Down Expand Up @@ -485,9 +484,7 @@ def __init__(self, config: Optional[Config] = None, **kwargs: Any) -> None:
if isinstance(results, dict):
for key, value in results.items():
results[key] = self._normalize(value)
elif isinstance(results, singletons) or hasattr(
results, "__iter__"
):
elif isinstance(results, singletons) or hasattr(results, "__iter__"):
results = self._normalize(results)
# Unknown input value: cry
else:
Expand Down Expand Up @@ -548,23 +545,23 @@ def _yield_result(self, attname: str, command: str) -> Result:
# raise_from(NotImplementedError(command), None)
raise NotImplementedError(command)

def run(self, command: str, *args: Any, **kwargs: Any) -> Result:
@annotate_run_function
def run(self, command: str, **kwargs: Any) -> Result:
# TODO: perform more convenience stuff associating args/kwargs with the
# result? E.g. filling in .command, etc? Possibly useful for debugging
# if one hits unexpected-order problems with what they passed in to
# __init__.
return self._yield_result("__run", command)

def sudo(self, command: str, *args: Any, **kwargs: Any) -> Result:
@annotate_run_function
def sudo(self, command: str, **kwargs: Any) -> Result:
# TODO: this completely nukes the top-level behavior of sudo(), which
# could be good or bad, depending. Most of the time I think it's good.
# No need to supply dummy password config, etc.
# TODO: see the TODO from run() re: injecting arg/kwarg values
return self._yield_result("__sudo", command)

def set_result_for(
self, attname: str, command: str, result: Result
) -> None:
def set_result_for(self, attname: str, command: str, result: Result) -> None:
"""
Modify the stored mock results for given ``attname`` (e.g. ``run``).

Expand Down
42 changes: 14 additions & 28 deletions invoke/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
except ImportError:
termios = None # type: ignore[assignment]

from ._types import annotate_run_function
from .exceptions import (
UnexpectedExit,
Failure,
Expand Down Expand Up @@ -122,6 +123,7 @@ def __init__(self, context: "Context") -> None:
self._asynchronous = False
self._disowned = False

@annotate_run_function
def run(self, command: str, **kwargs: Any) -> Optional["Result"]:
"""
Execute ``command``, returning an instance of `Result` once complete.
Expand Down Expand Up @@ -407,9 +409,7 @@ def _setup(self, command: str, kwargs: Any) -> None:
# Normalize kwargs w/ config; sets self.opts, self.streams
self._unify_kwargs_with_config(kwargs)
# Environment setup
self.env = self.generate_env(
self.opts["env"], self.opts["replace_env"]
)
self.env = self.generate_env(self.opts["env"], self.opts["replace_env"])
# Arrive at final encoding if neither config nor kwargs had one
self.encoding = self.opts["encoding"] or self.default_encoding()
# Echo running command (wants to be early to be included in dry-run)
Expand Down Expand Up @@ -544,7 +544,9 @@ def _unify_kwargs_with_config(self, kwargs: Any) -> None:
self._asynchronous = opts["asynchronous"]
self._disowned = opts["disown"]
if self._asynchronous and self._disowned:
err = "Cannot give both 'asynchronous' and 'disown' at the same time!" # noqa
err = (
"Cannot give both 'asynchronous' and 'disown' at the same time!" # noqa
)
raise ValueError(err)
# If hide was True, turn off echoing
if opts["hide"] is True:
Expand Down Expand Up @@ -600,9 +602,7 @@ def _collate_result(self, watcher_errors: List[WatcherError]) -> "Result":
# TODO: as noted elsewhere, I kinda hate this. Consider changing
# generate_result()'s API in next major rev so we can tidy up.
result = self.generate_result(
**dict(
self.result_kwargs, stdout=stdout, stderr=stderr, exited=exited
)
**dict(self.result_kwargs, stdout=stdout, stderr=stderr, exited=exited)
)
return result

Expand Down Expand Up @@ -753,9 +753,7 @@ def _handle_output(
# Run our specific buffer through the autoresponder framework
self.respond(buffer_)

def handle_stdout(
self, buffer_: List[str], hide: bool, output: IO
) -> None:
def handle_stdout(self, buffer_: List[str], hide: bool, output: IO) -> None:
"""
Read process' stdout, storing into a buffer & printing/parsing.

Expand All @@ -772,13 +770,9 @@ def handle_stdout(

.. versionadded:: 1.0
"""
self._handle_output(
buffer_, hide, output, reader=self.read_proc_stdout
)
self._handle_output(buffer_, hide, output, reader=self.read_proc_stdout)

def handle_stderr(
self, buffer_: List[str], hide: bool, output: IO
) -> None:
def handle_stderr(self, buffer_: List[str], hide: bool, output: IO) -> None:
"""
Read process' stderr, storing into a buffer & printing/parsing.

Expand All @@ -787,9 +781,7 @@ def handle_stderr(

.. versionadded:: 1.0
"""
self._handle_output(
buffer_, hide, output, reader=self.read_proc_stderr
)
self._handle_output(buffer_, hide, output, reader=self.read_proc_stderr)

def read_our_stdin(self, input_: IO) -> Optional[str]:
"""
Expand Down Expand Up @@ -938,9 +930,7 @@ def respond(self, buffer_: List[str]) -> None:
for response in watcher.submit(stream):
self.write_proc_stdin(response)

def generate_env(
self, env: Dict[str, Any], replace_env: bool
) -> Dict[str, Any]:
def generate_env(self, env: Dict[str, Any], replace_env: bool) -> Dict[str, Any]:
"""
Return a suitable environment dict based on user input & behavior.

Expand Down Expand Up @@ -1281,9 +1271,7 @@ def _write_proc_stdin(self, data: bytes) -> None:
elif self.process and self.process.stdin:
fd = self.process.stdin.fileno()
else:
raise SubprocessPipeError(
"Unable to write to missing subprocess or stdin!"
)
raise SubprocessPipeError("Unable to write to missing subprocess or stdin!")
# Try to write, ignoring broken pipes if encountered (implies child
# process exited before the process piping stdin to us finished;
# there's nothing we can do about that!)
Expand All @@ -1301,9 +1289,7 @@ def close_proc_stdin(self) -> None:
elif self.process and self.process.stdin:
self.process.stdin.close()
else:
raise SubprocessPipeError(
"Unable to close missing subprocess or stdin!"
)
raise SubprocessPipeError("Unable to close missing subprocess or stdin!")

def start(self, command: str, shell: str, env: Dict[str, Any]) -> None:
if self.using_pty:
Expand Down
4 changes: 2 additions & 2 deletions invoke/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def run(self) -> None:
# doesn't appear to be the case, then assume we're being used
# directly and just use super() ourselves.
# XXX https://github.com/python/mypy/issues/1424
if hasattr(self, "_run") and callable(self._run): # type: ignore
if hasattr(self, "_run") and callable(self._run):
# TODO: this could be:
# - io worker with no 'result' (always local)
# - tunnel worker, also with no 'result' (also always local)
Expand All @@ -206,7 +206,7 @@ def run(self) -> None:
# and let it continue acting like a normal thread (meh)
# - assume the run/sudo/etc case will use a queue inside its
# worker body, orthogonal to how exception handling works
self._run() # type: ignore
self._run()
else:
super().run()
except BaseException:
Expand Down