diff --git a/pyproject.toml b/pyproject.toml
index fe946af8..14ac935b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "uipath-langchain"
-version = "0.0.139"
+version = "0.0.140"
 description = "UiPath Langchain"
 readme = { file = "README.md", content-type = "text/markdown" }
 requires-python = ">=3.10"
diff --git a/src/uipath_langchain/chat/models.py b/src/uipath_langchain/chat/models.py
index 044cbbb1..133f71d7 100644
--- a/src/uipath_langchain/chat/models.py
+++ b/src/uipath_langchain/chat/models.py
@@ -1,15 +1,15 @@
 import json
 import logging
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import AIMessage, BaseMessage
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.messages.ai import UsageMetadata
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable
 from langchain_openai.chat_models import AzureChatOpenAI
 from pydantic import BaseModel
@@ -49,6 +49,54 @@ async def _agenerate(
         response = await self._acall(self.url, payload, self.auth_headers)
         return self._create_chat_result(response)
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        if "tools" in kwargs and not kwargs["tools"]:
+            del kwargs["tools"]
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        response = self._call(self.url, payload, self.auth_headers)
+
+        # For non-streaming response, yield single chunk
+        chat_result = self._create_chat_result(response)
+        chunk = ChatGenerationChunk(
+            message=AIMessageChunk(
+                content=chat_result.generations[0].message.content,
+                additional_kwargs=chat_result.generations[0].message.additional_kwargs,
+                response_metadata=chat_result.generations[0].message.response_metadata,
+                usage_metadata=chat_result.generations[0].message.usage_metadata,  # type: ignore
+            )
+        )
+        yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        if "tools" in kwargs and not kwargs["tools"]:
+            del kwargs["tools"]
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        response = await self._acall(self.url, payload, self.auth_headers)
+
+        # For non-streaming response, yield single chunk
+        chat_result = self._create_chat_result(response)
+        chunk = ChatGenerationChunk(
+            message=AIMessageChunk(
+                content=chat_result.generations[0].message.content,
+                additional_kwargs=chat_result.generations[0].message.additional_kwargs,
+                response_metadata=chat_result.generations[0].message.response_metadata,
+                usage_metadata=chat_result.generations[0].message.usage_metadata,  # type: ignore
+            )
+        )
+        yield chunk
+
     def with_structured_output(
         self,
         schema: Optional[Any] = None,
@@ -217,6 +265,92 @@ async def _agenerate(
         response = await self._acall(self.url, payload, self.auth_headers)
         return self._create_chat_result(response)
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        """Stream the LLM on a given prompt.
+
+        Args:
+            messages: the prompt composed of a list of messages.
+            stop: a list of strings on which the model should stop generating.
+            run_manager: A run manager with callbacks for the LLM.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            An iterator of ChatGenerationChunk objects.
+        """
+        if kwargs.get("tools"):
+            kwargs["tools"] = [tool["function"] for tool in kwargs["tools"]]
+            if "tool_choice" in kwargs and kwargs["tool_choice"]["type"] == "function":
+                kwargs["tool_choice"] = {
+                    "type": "tool",
+                    "name": kwargs["tool_choice"]["function"]["name"],
+                }
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        response = self._call(self.url, payload, self.auth_headers)
+
+        # For non-streaming response, yield single chunk
+        chat_result = self._create_chat_result(response)
+        chunk = ChatGenerationChunk(
+            message=AIMessageChunk(
+                content=chat_result.generations[0].message.content,
+                additional_kwargs=chat_result.generations[0].message.additional_kwargs,
+                response_metadata=chat_result.generations[0].message.response_metadata,
+                usage_metadata=chat_result.generations[0].message.usage_metadata,  # type: ignore
+                tool_calls=getattr(
+                    chat_result.generations[0].message, "tool_calls", []
+                ),
+            )
+        )
+        yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Async stream the LLM on a given prompt.
+
+        Args:
+            messages: the prompt composed of a list of messages.
+            stop: a list of strings on which the model should stop generating.
+            run_manager: A run manager with callbacks for the LLM.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            An async iterator of ChatGenerationChunk objects.
+        """
+        if kwargs.get("tools"):
+            kwargs["tools"] = [tool["function"] for tool in kwargs["tools"]]
+            if "tool_choice" in kwargs and kwargs["tool_choice"]["type"] == "function":
+                kwargs["tool_choice"] = {
+                    "type": "tool",
+                    "name": kwargs["tool_choice"]["function"]["name"],
+                }
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        response = await self._acall(self.url, payload, self.auth_headers)
+
+        # For non-streaming response, yield single chunk
+        chat_result = self._create_chat_result(response)
+        chunk = ChatGenerationChunk(
+            message=AIMessageChunk(
+                content=chat_result.generations[0].message.content,
+                additional_kwargs=chat_result.generations[0].message.additional_kwargs,
+                response_metadata=chat_result.generations[0].message.response_metadata,
+                usage_metadata=chat_result.generations[0].message.usage_metadata,  # type: ignore
+                tool_calls=getattr(
+                    chat_result.generations[0].message, "tool_calls", []
+                ),
+            )
+        )
+        yield chunk
+
     def with_structured_output(
         self,
         schema: Optional[Any] = None,
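
Usage note: both _stream implementations above call the existing non-streaming endpoint and emit the whole completion as a single ChatGenerationChunk, so LangChain's standard .stream() / .astream() APIs now work against these models without any backend change. A minimal sketch of how this could be exercised follows; the class name UiPathChat and the model id are assumptions for illustration only, since the class definitions are not visible in these hunks.

    from uipath_langchain.chat.models import UiPathChat  # class name assumed, not shown in the diff

    llm = UiPathChat(model="gpt-4o-2024-08-06")  # model id illustrative

    # .stream() drives the new _stream(); because the endpoint does not
    # actually stream, the full completion arrives as one chunk.
    for chunk in llm.stream("Write a haiku about robots"):
        print(chunk.content, end="", flush=True)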