diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py
index 5a9dbc3bc..2f3b88c13 100644
--- a/pr_agent/algo/ai_handlers/langchain_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py
@@ -20,35 +20,13 @@ def __init__(self):
         # Initialize OpenAIHandler specific attributes here
         super().__init__()
         self.azure = get_settings().get("OPENAI.API_TYPE", "").lower() == "azure"
-        try:
-            if self.azure:
-                # using a partial function so we can set the deployment_id later to support fallback_deployments
-                # but still need to access the other settings now so we can raise a proper exception if they're missing
-                self._chat = functools.partial(
-                    lambda **kwargs: AzureChatOpenAI(**kwargs),
-                    openai_api_key=get_settings().openai.key,
-                    openai_api_base=get_settings().openai.api_base,
-                    openai_api_version=get_settings().openai.api_version,
-                )
-            else:
-                # for llms that compatible with openai, should use custom api base
-                openai_api_base = get_settings().get("OPENAI.API_BASE", None)
-                if openai_api_base is None or len(openai_api_base) == 0:
-                    self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key)
-                else:
-                    self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base)
-        except AttributeError as e:
-            if getattr(e, "name"):
-                raise ValueError(f"OpenAI {e.name} is required") from e
-            else:
-                raise e
+
+        # Create a default unused chat object to trigger early validation
+        self._create_chat(self.deployment_id)
 
     def chat(self, messages: list, model: str, temperature: float):
-        if self.azure:
-            # we must set the deployment_id only here (instead of the __init__ method) to support fallback_deployments
-            return self._chat.invoke(input = messages, model=model, temperature=temperature, deployment_name=self.deployment_id)
-        else:
-            return self._chat.invoke(input = messages, model=model, temperature=temperature)
+        chat = self._create_chat(self.deployment_id)
+        return chat.invoke(input=messages, model=model, temperature=temperature)
 
     @property
     def deployment_id(self):
@@ -71,3 +49,28 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
         except (Exception) as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise e
+
+    def _create_chat(self, deployment_id=None):
+        try:
+            if self.azure:
+                # using a partial function so we can set the deployment_id later to support fallback_deployments
+                # but still need to access the other settings now so we can raise a proper exception if they're missing
+                return AzureChatOpenAI(
+                    openai_api_key=get_settings().openai.key,
+                    openai_api_version=get_settings().openai.api_version,
+                    azure_deployment=deployment_id,
+                    azure_endpoint=get_settings().openai.api_base,
+                )
+            else:
+                # for llms that compatible with openai, should use custom api base
+                openai_api_base = get_settings().get("OPENAI.API_BASE", None)
+                if openai_api_base is None or len(openai_api_base) == 0:
+                    return ChatOpenAI(openai_api_key=get_settings().openai.key)
+                else:
+                    return ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base)
+        except AttributeError as e:
+            if getattr(e, "name"):
+                raise ValueError(f"OpenAI {e.name} is required") from e
+            else:
+                raise e
+
diff --git a/requirements.txt b/requirements.txt
index f9c976f15..4337fe344 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -30,4 +30,6 @@ gunicorn==20.1.0
 # pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main
 # lancedb==0.5.1
 # uncomment this to support language LangChainOpenAIHandler
-# langchain==0.0.349
+# langchain==0.2.0
+# langchain-core==0.2.28
+# langchain-openai==0.1.20