Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis committed Nov 22, 2024
2 parents a248164 + d7e61a5 commit 6836e5e
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 108 deletions.
102 changes: 39 additions & 63 deletions libs/partners/ollama/langchain_ollama/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,13 +341,22 @@ class Multiply(BaseModel):
The async client to use for making requests.
"""

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Ollama."""
return {
"model": self.model,
"format": self.format,
"options": {
def _chat_params(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
ollama_messages = self._convert_messages_to_ollama_messages(messages)

if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop

options_dict = kwargs.pop(
"options",
{
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
Expand All @@ -359,14 +368,31 @@ def _default_params(self) -> Dict[str, Any]:
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"seed": self.seed,
"stop": self.stop,
"stop": self.stop if stop is None else stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
},
"keep_alive": self.keep_alive,
)

tools = kwargs.get("tools")
default_stream = not bool(tools)

params = {
"messages": ollama_messages,
"stream": kwargs.pop("stream", default_stream),
"model": kwargs.pop("model", self.model),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
**kwargs,
}

if tools:
params["tools"] = tools

return params

@model_validator(mode="after")
def _set_clients(self) -> Self:
"""Set clients to use for ollama."""
Expand Down Expand Up @@ -464,34 +490,9 @@ async def _acreate_chat_stream(
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
ollama_messages = self._convert_messages_to_ollama_messages(messages)

stop = stop if stop is not None else self.stop

params = self._default_params
chat_params = self._chat_params(messages, stop, **kwargs)

for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]

params["options"]["stop"] = stop

tools = kwargs.get("tools", None)
stream = tools is None or len(tools) == 0

chat_params = {
"model": params["model"],
"messages": ollama_messages,
"stream": stream,
"options": Options(**params["options"]),
"keep_alive": params["keep_alive"],
"format": params["format"],
}

if tools is not None:
chat_params["tools"] = tools

if stream:
if chat_params["stream"]:
async for part in await self._async_client.chat(**chat_params):
yield part
else:
Expand All @@ -503,34 +504,9 @@ def _create_chat_stream(
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
ollama_messages = self._convert_messages_to_ollama_messages(messages)

stop = stop if stop is not None else self.stop

params = self._default_params

for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]

params["options"]["stop"] = stop

tools = kwargs.get("tools", None)
stream = tools is None or len(tools) == 0

chat_params = {
"model": params["model"],
"messages": ollama_messages,
"stream": stream,
"options": Options(**params["options"]),
"keep_alive": params["keep_alive"],
"format": params["format"],
}

if tools is not None:
chat_params["tools"] = tools
chat_params = self._chat_params(messages, stop, **kwargs)

if stream:
if chat_params["stream"]:
yield from self._client.chat(**chat_params)
else:
yield self._client.chat(**chat_params)
Expand Down
74 changes: 29 additions & 45 deletions libs/partners/ollama/langchain_ollama/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,20 @@ class OllamaLLM(BaseLLM):
The async client to use for making requests.
"""

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Ollama."""
return {
"model": self.model,
"format": self.format,
"options": {
def _generate_params(
self,
prompt: str,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop

options_dict = kwargs.pop(
"options",
{
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
Expand All @@ -143,14 +150,25 @@ def _default_params(self) -> Dict[str, Any]:
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"stop": self.stop,
"stop": self.stop if stop is None else stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
},
"keep_alive": self.keep_alive,
)

params = {
"prompt": prompt,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
**kwargs,
}

return params

@property
def _llm_type(self) -> str:
"""Return type of LLM."""
Expand Down Expand Up @@ -179,25 +197,8 @@ async def _acreate_generate_stream(
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop

params = self._default_params

for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]

params["options"]["stop"] = stop
async for part in await self._async_client.generate(
model=params["model"],
prompt=prompt,
stream=True,
options=Options(**params["options"]),
keep_alive=params["keep_alive"],
format=params["format"],
**self._generate_params(prompt, stop=stop, **kwargs)
): # type: ignore
yield part # type: ignore

Expand All @@ -207,25 +208,8 @@ def _create_generate_stream(
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop

params = self._default_params

for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]

params["options"]["stop"] = stop
yield from self._client.generate(
model=params["model"],
prompt=prompt,
stream=True,
options=Options(**params["options"]),
keep_alive=params["keep_alive"],
format=params["format"],
**self._generate_params(prompt, stop=stop, **kwargs)
) # type: ignore

async def _astream_with_aggregation(
Expand Down

0 comments on commit 6836e5e

Please sign in to comment.