diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index 9a48cdc3d..f45dc71dd 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -3,8 +3,10 @@
 import litellm
 import openai
 from litellm import acompletion
-from openai.error import APIError, RateLimitError, Timeout, TryAgain
+from openai import APIError, RateLimitError
 from retry import retry
+from tenacity import TryAgain
+
 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
 
@@ -24,7 +26,7 @@ def __init__(self):
         Raises a ValueError if the OpenAI key is missing.
         """
         self.azure = False
-
+        self.api_base = None
         if get_settings().get("OPENAI.KEY", None):
             openai.api_key = get_settings().openai.key
             litellm.openai_key = get_settings().openai.key
@@ -53,8 +55,9 @@ def __init__(self):
             litellm.replicate_key = get_settings().replicate.key
         if get_settings().get("HUGGINGFACE.KEY", None):
             litellm.huggingface_key = get_settings().huggingface.key
-        if get_settings().get("HUGGINGFACE.API_BASE", None):
-            litellm.api_base = get_settings().huggingface.api_base
+        if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
+            litellm.api_base = get_settings().huggingface.api_base
+            self.api_base = get_settings().huggingface.api_base
         if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
             litellm.vertex_project = get_settings().vertexai.vertex_project
             litellm.vertex_location = get_settings().get(
@@ -68,22 +71,29 @@ def deployment_id(self):
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
 
-    @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
-           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
+    @retry(
+        exceptions=(APIError, TryAgain, AttributeError, RateLimitError),
+        tries=OPENAI_RETRIES,
+        delay=2,
+        backoff=2,
+        jitter=(1, 3),
+    )
+    async def chat_completion(
+        self, model: str, system: str, user: str, temperature: float = 0.2
+    ):
         """
         Performs a chat completion using the OpenAI ChatCompletion API.
         Retries in case of API errors or timeouts.
-        
+
         Args:
             model (str): The model to use for chat completion.
             temperature (float): The temperature parameter for chat completion.
             system (str): The system message for chat completion.
             user (str): The user message for chat completion.
-        
+
         Returns:
             tuple: A tuple containing the response and finish reason from the API.
-        
+
         Raises:
             TryAgain: If the API response is empty or there are no choices in the response.
             APIError: If there is an error during OpenAI inference.
@@ -99,28 +109,52 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
             )
             if self.azure:
                 model = 'azure/' + model
+
+            system = get_settings().get("CONFIG.MODEL_SYSTEM_PREFIX", "") + system + \
+                get_settings().get("CONFIG.MODEL_SYSTEM_SUFFIX", "")
+            suffix = ''
+            yaml_start = '```yaml'
+            if user.endswith(yaml_start):
+                user = user[:-len(yaml_start)]
+                suffix = '\n' + yaml_start + '\n'
+            user = get_settings().get("CONFIG.MODEL_USER_PREFIX", "") + user + \
+                get_settings().get("CONFIG.MODEL_USER_SUFFIX", "") + suffix
             messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
+            stop = get_settings().get("CONFIG.MODEL_STOP", None)
             response = await acompletion(
                 model=model,
                 deployment_id=deployment_id,
                 messages=messages,
                 temperature=temperature,
-                force_timeout=get_settings().config.ai_timeout
+                force_timeout=get_settings().config.ai_timeout,
+                api_base=self.api_base,
+                stop=stop,
             )
-        except (APIError, Timeout, TryAgain) as e:
+        except (APIError, TryAgain) as e:
             get_logger().error("Error during OpenAI inference: ", e)
             raise
-        except (RateLimitError) as e:
+        except RateLimitError as e:
             get_logger().error("Rate limit error during OpenAI inference: ", e)
             raise
-        except (Exception) as e:
+        except Exception as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise TryAgain from e
         if response is None or len(response["choices"]) == 0:
             raise TryAgain
-        resp = response["choices"][0]['message']['content']
+        resp = response["choices"][0]["message"]["content"]
+        if stop:
+            for stop_word in stop:
+                if resp.endswith(stop_word):
+                    resp = resp[:-len(stop_word)]
+                    break
         finish_reason = response["choices"][0]["finish_reason"]
         usage = response.get("usage")
-        get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
-                          model=model, usage=usage)
+        get_logger().info(
+            "AI response",
+            response=resp,
+            messages=messages,
+            finish_reason=finish_reason,
+            model=model,
+            usage=usage,
+        )
         return resp, finish_reason
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 7a6e666c4..eb136bd9e 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -375,7 +375,7 @@ def get_user_labels(current_labels: List[str] = None):
 
 def get_max_tokens(model):
     settings = get_settings()
-    max_tokens_model = MAX_TOKENS[model]
+    max_tokens_model = MAX_TOKENS.get(model, 4000)
     if settings.config.max_model_tokens:
         max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
         # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py
index 1fb851644..bad1c44fb 100644
--- a/pr_agent/git_providers/github_provider.py
+++ b/pr_agent/git_providers/github_provider.py
@@ -327,7 +327,8 @@ def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]:
 
     def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
         try:
-            self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id)
+            if reaction_id:
+                self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id)
             return True
         except Exception as e:
             get_logger().exception(f"Failed to remove eyes reaction, error: {e}")
diff --git a/requirements.txt b/requirements.txt
index eae08f4cc..8cedbbf10 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ dynaconf==3.1.12
 fastapi==0.99.0
 PyGithub==1.59.*
 retry==0.9.2
-openai==0.27.8
+openai==1.3.5
 Jinja2==3.1.2
 tiktoken==0.4.0
 uvicorn==0.22.0
@@ -13,7 +13,6 @@ atlassian-python-api==3.39.0
 GitPython==3.1.32
 PyYAML==6.0
 starlette-context==0.3.6
-litellm==0.12.5
 boto3==1.28.25
 google-cloud-storage==2.10.0
 ujson==5.8.0
@@ -23,3 +22,5 @@ pinecone-client
 pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main
 loguru==0.7.2
 google-cloud-aiplatform==1.35.0
+litellm==1.7.1
+tenacity==8.1.0
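
Note on the new prompt shaping in chat_completion: models served through a raw completion endpoint (for example a HuggingFace-hosted model reached via the new api_base) often need instruction markers wrapped around the chat messages and tend to echo their stop token at the end of the output. The patch handles this with the CONFIG.MODEL_SYSTEM_PREFIX/SUFFIX, CONFIG.MODEL_USER_PREFIX/SUFFIX, and CONFIG.MODEL_STOP settings, plus a special case that keeps a trailing "```yaml" opener after the user suffix so the model still begins its answer inside the YAML block. A minimal standalone sketch of that logic follows; the helper names and the Llama-2-style [INST] markers are illustrative, not part of the patch:

def wrap_user_prompt(user: str, prefix: str = "", suffix: str = "") -> str:
    # Mirrors the diff: temporarily strip a trailing '```yaml' opener so the
    # model-specific suffix lands before it, then re-append the opener.
    tail = ""
    yaml_start = "```yaml"
    if user.endswith(yaml_start):
        user = user[:-len(yaml_start)]
        tail = "\n" + yaml_start + "\n"
    return prefix + user + suffix + tail


def trim_stop_words(resp: str, stop: list) -> str:
    # Mirrors the diff: drop a single trailing stop word that completion
    # backends sometimes echo into the generated text.
    for stop_word in stop or []:
        if resp.endswith(stop_word):
            return resp[:-len(stop_word)]
    return resp


print(wrap_user_prompt("Describe the PR in:\n```yaml", prefix="[INST] ", suffix=" [/INST]"))
# -> "[INST] Describe the PR in:\n [/INST]\n```yaml\n"
print(trim_stop_words("type: bugfix\n</s>", ["</s>"]))
# -> "type: bugfix\n"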
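
On the import swap at the top of ai_handler.py: openai 1.x removed the openai.error module, so APIError and RateLimitError now come from the package root, and Timeout/TryAgain no longer exist there at all. tenacity.TryAgain is an ordinary Exception subclass, which is all the retry decorator needs as a "please retry" signal for empty responses. A small sketch of the pattern, with a made-up flaky_call for the demo (the decorator arguments match the patch, minus the OpenAI exception types, so it runs without an API key):

from retry import retry
from tenacity import TryAgain  # plain Exception subclass; no tenacity machinery needed

OPENAI_RETRIES = 5  # same constant name as in ai_handler.py
attempts = {"count": 0}

@retry(exceptions=(TryAgain,), tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
def flaky_call() -> str:
    # Raise TryAgain for "empty response" conditions, exactly as
    # chat_completion does when the API returns no choices.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise TryAgain
    return "ok"

print(flaky_call())  # two retries (with backoff sleeps), then "ok"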