From ecb7b26e15e67196741a5495a9e541bc23c3253f Mon Sep 17 00:00:00 2001 From: Farkhod Sadykov Date: Sat, 17 Feb 2024 02:58:55 +0100 Subject: [PATCH] New config variable API_BASE_URL and minor bug fixes (#477) --- README.md | 4 ++-- pyproject.toml | 2 +- sgpt/__main__.py | 3 +++ sgpt/__version__.py | 2 +- sgpt/config.py | 2 +- sgpt/handlers/chat_handler.py | 10 ++++++---- sgpt/handlers/handler.py | 6 ++++++ tests/utils.py | 2 ++ 8 files changed, 22 insertions(+), 9 deletions(-) create mode 100644 sgpt/__main__.py diff --git a/README.md b/README.md index d08d763b..2b642ef5 100644 --- a/README.md +++ b/README.md @@ -382,8 +382,8 @@ You can setup some parameters in runtime configuration file `~/.config/shell_gpt ```text # API key, also it is possible to define OPENAI_API_KEY env. OPENAI_API_KEY=your_api_key -# OpenAI host, useful if you would like to use proxy. -OPENAI_API_HOST=https://api.openai.com +# Base URL of the backend server. If "default" URL will be resolved based on --model. +API_BASE_URL=default # Max amount of cached message per chat session. CHAT_CACHE_LENGTH=100 # Chat cache folder. diff --git a/pyproject.toml b/pyproject.toml index 82726a9f..e97586ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] dependencies = [ - "litellm >= 1.20.1, < 2.0.0", + "litellm == 1.24.5", "typer >= 0.7.0, < 1.0.0", "click >= 7.1.1, < 9.0.0", "rich >= 13.1.0, < 14.0.0", diff --git a/sgpt/__main__.py b/sgpt/__main__.py new file mode 100644 index 00000000..1c7bc113 --- /dev/null +++ b/sgpt/__main__.py @@ -0,0 +1,3 @@ +from .app import entry_point + +entry_point() diff --git a/sgpt/__version__.py b/sgpt/__version__.py index 67bc602a..9c73af26 100644 --- a/sgpt/__version__.py +++ b/sgpt/__version__.py @@ -1 +1 @@ -__version__ = "1.3.0" +__version__ = "1.3.1" diff --git a/sgpt/config.py b/sgpt/config.py index f7ebf34d..20c71048 100644 --- a/sgpt/config.py +++ b/sgpt/config.py @@ -23,7 +23,6 @@ "CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")), "REQUEST_TIMEOUT": int(os.getenv("REQUEST_TIMEOUT", "60")), "DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "gpt-4-1106-preview"), - "OPENAI_BASE_URL": os.getenv("OPENAI_API_HOST", "https://api.openai.com/v1"), "DEFAULT_COLOR": os.getenv("DEFAULT_COLOR", "magenta"), "ROLE_STORAGE_PATH": os.getenv("ROLE_STORAGE_PATH", str(ROLE_STORAGE_PATH)), "DEFAULT_EXECUTE_SHELL_CMD": os.getenv("DEFAULT_EXECUTE_SHELL_CMD", "false"), @@ -32,6 +31,7 @@ "OPENAI_FUNCTIONS_PATH": os.getenv("OPENAI_FUNCTIONS_PATH", str(FUNCTIONS_PATH)), "OPENAI_USE_FUNCTIONS": os.getenv("OPENAI_USE_FUNCTIONS", "true"), "SHOW_FUNCTIONS_OUTPUT": os.getenv("SHOW_FUNCTIONS_OUTPUT", "false"), + "API_BASE_URL": os.getenv("API_BASE_URL", "default"), # New features might add their own config variables here. } diff --git a/sgpt/handlers/chat_handler.py b/sgpt/handlers/chat_handler.py index b315bd34..8704507d 100644 --- a/sgpt/handlers/chat_handler.py +++ b/sgpt/handlers/chat_handler.py @@ -128,18 +128,20 @@ def list_ids(cls, value: str) -> None: @classmethod def show_messages(cls, chat_id: str) -> None: + color = cfg.get("DEFAULT_COLOR") if "APPLY MARKDOWN" in cls.initial_message(chat_id): + theme = cfg.get("CODE_THEME") for message in cls.chat_session.get_messages(chat_id): if message.startswith("assistant:"): - Console().print(Markdown(message)) + Console().print(Markdown(message, code_theme=theme)) else: - typer.secho(message, fg=cfg.get("DEFAULT_COLOR")) + typer.secho(message, fg=color) typer.echo() return for index, message in enumerate(cls.chat_session.get_messages(chat_id)): - color = "magenta" if index % 2 == 0 else "green" - typer.secho(message, fg=color) + running_color = color if index % 2 == 0 else "green" + typer.secho(message, fg=running_color) @classmethod @option_callback diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index 8cecc56d..0cc7dc40 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -19,6 +19,10 @@ class Handler: def __init__(self, role: SystemRole) -> None: self.role = role + api_base_url = cfg.get("API_BASE_URL") + self.base_url = None if api_base_url == "default" else api_base_url + self.timeout = int(cfg.get("REQUEST_TIMEOUT")) + @property def printer(self) -> Printer: use_markdown = "APPLY MARKDOWN" in self.role.role @@ -78,6 +82,8 @@ def get_completion( functions=functions, stream=True, api_key=cfg.get("OPENAI_API_KEY"), + base_url=self.base_url, + timeout=self.timeout, ): delta = chunk.choices[0].delta function_call = delta.get("function_call") diff --git a/tests/utils.py b/tests/utils.py index ed0998c1..f2a022f7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -41,5 +41,7 @@ def comp_args(role, prompt, **kwargs): "functions": None, "stream": True, "api_key": ANY, + "base_url": ANY, + "timeout": ANY, **kwargs, }