Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/goose/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from goose.utils.autocomplete import SUPPORTED_SHELLS, setup_autocomplete
from goose.utils.session_file import list_sorted_session_files

LOG_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
LOG_CHOICE = click.Choice(LOG_LEVELS)


@click.group()
def goose_cli() -> None:
Expand Down Expand Up @@ -135,7 +138,7 @@ def get_session_files() -> dict[str, Path]:
@click.argument("name", required=False, shell_complete=autocomplete_session_files)
@click.option("--profile")
@click.option("--plan", type=click.Path(exists=True))
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
def session_start(name: Optional[str], profile: str, log_level: str, plan: Optional[str] = None) -> None:
"""Start a new goose session"""
if plan:
Expand All @@ -161,7 +164,7 @@ def parse_args(ctx: click.Context, param: click.Parameter, value: str) -> dict[s

@session.command(name="planned")
@click.option("--plan", type=click.Path(exists=True))
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
@click.option("-a", "--args", callback=parse_args, help="Args in the format arg1:value1,arg2:value2")
def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) -> None:
plan_templated = render_template(Path(plan), context=args)
Expand All @@ -173,7 +176,7 @@ def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) -
@session.command(name="resume")
@click.argument("name", required=False, shell_complete=autocomplete_session_files)
@click.option("--profile")
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
def session_resume(name: Optional[str], profile: str, log_level: str) -> None:
"""Resume an existing goose session"""
session_files = get_session_files()
Expand All @@ -190,13 +193,13 @@ def session_resume(name: Optional[str], profile: str, log_level: str) -> None:
else:
print(f"Creating new session: {name}")
session = Session(name=name, profile=profile, log_level=log_level)
session.run()
session.run(new_session=False)


@goose_cli.command(name="run")
@click.argument("message_file", required=False, type=click.Path(exists=True))
@click.option("--profile")
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
def run(message_file: Optional[str], profile: str, log_level: str) -> None:
"""Run a single-pass session with a message from a markdown input file"""
if message_file:
Expand Down
6 changes: 4 additions & 2 deletions src/goose/cli/prompt/overwrite_session_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
self.default = "resume"

def check_choice(self, choice: str) -> bool:
normalized_choice = choice.lower()
for key in self.choices:
normalized_choice = choice.lower()
if normalized_choice == key or normalized_choice[0] == key[0]:
is_key = normalized_choice == key
is_first_letter = normalized_choice and normalized_choice[0] == key[0]
if is_key or is_first_letter:
return True
return False

Expand Down
7 changes: 5 additions & 2 deletions src/goose/cli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,15 @@ def single_pass(self, initial_message: str) -> None:
print(f"[dim]ended run | name:[cyan]{self.name}[/] profile:[cyan]{profile}[/]")
print(f"[dim]to resume: [magenta]goose session resume {self.name} --profile {profile}[/][/]")

def run(self) -> None:
def run(self, new_session: bool = True) -> None:
"""
Runs the main loop to handle user inputs and responses.
Continues until an empty string is returned from the prompt.

Args:
new_session (bool): True when starting a new session, False when resuming.
"""
if is_existing_session(self.session_file_path):
if is_existing_session(self.session_file_path) and new_session:
self._prompt_overwrite_session()

profile_name = self.profile_name or "default"
Expand Down
49 changes: 49 additions & 0 deletions tests/cli/prompt/test_overwrite_session_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt


@pytest.fixture
def prompt():
return OverwriteSessionPrompt()


def test_init(prompt):
assert prompt.choices == {
"yes": "Overwrite the existing session",
"no": "Pick a new session name",
"resume": "Resume the existing session",
}
assert prompt.default == "resume"


@pytest.mark.parametrize(
"choice, expected",
[
("", False),
("invalid", False),
("n", True),
("N", True),
("no", True),
("NO", True),
("r", True),
("R", True),
("resume", True),
("RESUME", True),
("y", True),
("Y", True),
("yes", True),
("YES", True),
],
)
def test_check_choice(prompt, choice, expected):
assert prompt.check_choice(choice) == expected


def test_instantiation():
prompt = OverwriteSessionPrompt()
assert prompt.choices == {
"yes": "Overwrite the existing session",
"no": "Pick a new session name",
"resume": "Resume the existing session",
}
assert prompt.default == "resume"
25 changes: 17 additions & 8 deletions tests/cli/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,20 @@ def test_set_generated_session_name(mock_droid, create_session_with_mock_configs
def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs):
session = create_session_with_mock_configs({"name": SESSION_NAME})

mock_is_existing.return_value = True
session.run()
mock_prompt.assert_called_once()

mock_prompt.reset_mock()
mock_is_existing.return_value = False
session.run()
mock_prompt.assert_not_called()
def check_prompt_behavior(is_existing, new_session, should_prompt):
mock_is_existing.return_value = is_existing
if new_session is None:
session.run()
else:
session.run(new_session=new_session)

if should_prompt:
mock_prompt.assert_called_once()
else:
mock_prompt.assert_not_called()
mock_prompt.reset_mock()

check_prompt_behavior(is_existing=True, new_session=None, should_prompt=True)
check_prompt_behavior(is_existing=False, new_session=None, should_prompt=False)
check_prompt_behavior(is_existing=True, new_session=True, should_prompt=True)
check_prompt_behavior(is_existing=False, new_session=False, should_prompt=False)