Skip to content

Commit

Permalink
Several typing fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanslenders committed May 16, 2024
1 parent 3ec97d7 commit c1a4310
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
14 changes: 6 additions & 8 deletions ptpython/python_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,14 +347,6 @@ def __init__(
"classic": ClassicPrompt(),
}

self.get_input_prompt = lambda: self.all_prompt_styles[
self.prompt_style
].in_prompt()

self.get_output_prompt = lambda: self.all_prompt_styles[
self.prompt_style
].out_prompt()

#: Load styles.
self.code_styles: dict[str, BaseStyle] = get_all_code_styles()
self.ui_styles = get_all_ui_styles()
Expand Down Expand Up @@ -425,6 +417,12 @@ def __init__(
else:
self._app = None

def get_input_prompt(self) -> AnyFormattedText:
return self.all_prompt_styles[self.prompt_style].in_prompt()

def get_output_prompt(self) -> AnyFormattedText:
return self.all_prompt_styles[self.prompt_style].out_prompt()

def _accept_handler(self, buff: Buffer) -> bool:
app = get_app()
app.exit(result=buff.text)
Expand Down
17 changes: 9 additions & 8 deletions ptpython/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import types
import warnings
from dis import COMPILER_FLAG_NAMES
from typing import Any, Callable, ContextManager, Iterable
from pathlib import Path
from typing import Any, Callable, ContextManager, Iterable, Sequence

from prompt_toolkit.formatted_text import OneStyleAndTextTuple
from prompt_toolkit.patch_stdout import patch_stdout as patch_stdout_context
Expand Down Expand Up @@ -64,7 +65,7 @@ def _has_coroutine_flag(code: types.CodeType) -> bool:

class PythonRepl(PythonInput):
def __init__(self, *a, **kw) -> None:
self._startup_paths = kw.pop("startup_paths", None)
self._startup_paths: Sequence[str | Path] | None = kw.pop("startup_paths", None)
super().__init__(*a, **kw)
self._load_start_paths()

Expand Down Expand Up @@ -348,7 +349,7 @@ def _store_eval_result(self, result: object) -> None:
def get_compiler_flags(self) -> int:
return super().get_compiler_flags() | PyCF_ALLOW_TOP_LEVEL_AWAIT

def _compile_with_flags(self, code: str, mode: str):
def _compile_with_flags(self, code: str, mode: str) -> Any:
"Compile code with the right compiler flags."
return compile(
code,
Expand Down Expand Up @@ -459,13 +460,13 @@ def enter_to_continue() -> None:


def embed(
globals=None,
locals=None,
globals: dict[str, Any] | None = None,
locals: dict[str, Any] | None = None,
configure: Callable[[PythonRepl], None] | None = None,
vi_mode: bool = False,
history_filename: str | None = None,
title: str | None = None,
startup_paths=None,
startup_paths: Sequence[str | Path] | None = None,
patch_stdout: bool = False,
return_asyncio_coroutine: bool = False,
) -> None:
Expand Down Expand Up @@ -494,10 +495,10 @@ def embed(

locals = locals or globals

def get_globals():
def get_globals() -> dict[str, Any]:
return globals

def get_locals():
def get_locals() -> dict[str, Any]:
return locals

# Create REPL.
Expand Down

0 comments on commit c1a4310

Please sign in to comment.