Skip to content

Commit

Permalink
Add up-arrow through history
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley committed Dec 20, 2023
1 parent 8af95ba commit aa40f9c
Showing 1 changed file with 17 additions and 36 deletions.
53 changes: 17 additions & 36 deletions shelloracle/shelloracle.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,23 @@
import asyncio
import os
import sys
from pathlib import Path

from prompt_toolkit import PromptSession
from prompt_toolkit.application import create_app_session_from_tty
from prompt_toolkit.history import FileHistory

from .provider import get_provider

ollama = get_provider("ollama")()


class ShitPluginException(Exception):
...


async def prompt_user(default_prompt: str | None) -> None:
"""
Coroutine function to prompt the user for input, optionally preloading with a default value.
:param default_prompt: The default text to preload into the prompt
:return: None
"""
async def prompt_user(default_prompt: str | None = None) -> str:
with create_app_session_from_tty():
prompt_session = PromptSession()
prompt_session.output.write_raw("\033[E") # Can I do this with one of the builtin methods
query = await prompt_session.prompt_async("> ", default=default_prompt or "")
async for item in ollama.generate(query):
sys.stdout.write(item)


async def generate_response(query: str) -> None:
"""Generate a shell command based on the query.
:param query: The query to generate a shell command for
:return: None
"""
try:
...
except Exception as e:
raise ShitPluginException("Uh oh! Shit plugin alert. The LLM plugin you are using shit the bed.") from e
history_file = Path.home() / ".shelloracle_history"
prompt_session = PromptSession(history=FileHistory(str(history_file)))
# Can I do this with one of the builtin methods?
# I tried a few (including cursor_down) with limited success
prompt_session.output.write_raw("\033[E")
return await prompt_session.prompt_async("> ", default=default_prompt or "")


def get_query_from_pipe() -> str | None:
Expand All @@ -47,7 +26,7 @@ def get_query_from_pipe() -> str | None:
:raises ValueError: If the input is more than one line
:return: The query from the stdin pipe
"""
if os.isatty(0): # Return 'None' if nothing is in the pipe
if os.isatty(0): # Return 'None' if fd 0 is a tty (no pipe)
return None
if not (lines := sys.stdin.readlines()):
return None
Expand All @@ -69,12 +48,14 @@ async def shell_oracle() -> None:
:returns: None
"""
if query := get_query_from_pipe():
await generate_response(query)
return
provider = get_provider("ollama")()

if not (prompt := get_query_from_pipe()):
default_prompt = os.environ.get("SHOR_DEFAULT_PROMPT")
prompt = await prompt_user(default_prompt)

default_prompt = os.environ.get("SHOR_DEFAULT_PROMPT")
await prompt_user(default_prompt)
async for token in provider.generate(prompt):
sys.stdout.write(token)


def cli() -> None:
Expand Down

0 comments on commit aa40f9c

Please sign in to comment.