chore: increase max tokens and bump version #2460

Merged: 8 commits merged on Feb 27, 2025
2 changes: 1 addition & 1 deletion letta/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.6.33"
__version__ = "0.6.34"

# import clients
from letta.client.client import LocalClient, RESTClient, create_client
14 changes: 12 additions & 2 deletions letta/agent.py
@@ -832,7 +832,7 @@ def inner_step(
)

if current_total_tokens > summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window):
printd(
logger.warning(
f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}"
)

@@ -842,7 +842,7 @@ def inner_step(
self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this

else:
printd(
logger.info(
f"last response total_tokens ({current_total_tokens}) < {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}"
)

@@ -892,6 +892,16 @@ def inner_step(
if is_context_overflow_error(e):
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)

# TODO: this is a patch to resolve immediate issues; it should be removed once the summarizer is fixed
if self.agent_state.message_buffer_autoclear:
# don't call the summarizer in this case
logger.error(
f"step() failed with an exception that looks like a context window overflow, but message buffer is set to autoclear, so skipping: '{str(e)}'"
)
raise e

summarize_attempt_count += 1

if summarize_attempt_count <= summarizer_settings.max_summarizer_retries:
logger.warning(
f"context window exceeded with limit {self.agent_state.llm_config.context_window}, attempting to summarize ({summarize_attempt_count}/{summarizer_settings.max_summarizer_retries}"
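The net effect of the agent.py changes: memory-pressure messages now go through the standard logger instead of printd, and a context-overflow error is re-raised immediately when the message buffer auto-clears instead of invoking the summarizer. A toy, self-contained sketch of that decision (hypothetical function and argument names; the real logic lives inside Agent.inner_step):

import logging

logger = logging.getLogger("letta.agent")

def handle_context_overflow(error: Exception, autoclear: bool, attempt: int, max_retries: int) -> int:
    """Return the updated summarize-attempt count, or re-raise the error."""
    if autoclear:
        # New in this PR: with buffer autoclear enabled, summarizing is skipped
        # and the original overflow error is surfaced to the caller.
        logger.error("context overflow with autoclear enabled, skipping summarizer: %s", error)
        raise error
    attempt += 1
    if attempt <= max_retries:
        logger.warning("context window exceeded, attempting to summarize (%d/%d)", attempt, max_retries)
        # ... the real code summarizes the in-context messages and retries here ...
    return attempt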
59 changes: 58 additions & 1 deletion letta/llm_api/llm_api_tools.py
@@ -187,8 +187,65 @@ def create(
function_call = "required"

data = build_openai_chat_completions_request(
llm_config, messages, user_id, functions, function_call, use_tool_naming, put_inner_thoughts_first=put_inner_thoughts_first
llm_config,
messages,
user_id,
functions,
function_call,
use_tool_naming,
put_inner_thoughts_first=put_inner_thoughts_first,
use_structured_output=True, # NOTE: turn on all the time for OpenAI API
)

if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
stream_interface, AgentRefreshStreamingInterface
), type(stream_interface)
response = openai_chat_completions_process_stream(
url=llm_config.model_endpoint,
api_key=api_key,
chat_completion_request=data,
stream_interface=stream_interface,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_start()
try:
response = openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=api_key,
chat_completion_request=data,
)
finally:
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_end()

if llm_config.put_inner_thoughts_in_kwargs:
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)

return response

elif llm_config.model_endpoint_type == "xai":

api_key = model_settings.xai_api_key

if function_call is None and functions is not None and len(functions) > 0:
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
function_call = "required"

data = build_openai_chat_completions_request(
llm_config,
messages,
user_id,
functions,
function_call,
use_tool_naming,
put_inner_thoughts_first=put_inner_thoughts_first,
use_structured_output=False, # NOTE: not supported atm for xAI
)

if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
26 changes: 19 additions & 7 deletions letta/llm_api/openai.py
Expand Up @@ -13,7 +13,7 @@
from letta.schemas.message import MessageRole as _MessageRole
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
from letta.schemas.openai.chat_completion_request import Tool, ToolFunctionChoice, cast_message_to_subtype
from letta.schemas.openai.chat_completion_request import FunctionSchema, Tool, ToolFunctionChoice, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import (
ChatCompletionChunkResponse,
ChatCompletionResponse,
@@ -95,6 +95,7 @@ def build_openai_chat_completions_request(
function_call: Optional[str],
use_tool_naming: bool,
put_inner_thoughts_first: bool = True,
use_structured_output: bool = True,
) -> ChatCompletionRequest:
if functions and llm_config.put_inner_thoughts_in_kwargs:
# Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
@@ -157,6 +158,16 @@
data.user = str(uuid.UUID(int=0))
data.model = "memgpt-openai"

if use_structured_output and data.tools is not None and len(data.tools) > 0:
# Convert to structured output style (which has 'strict' and no optionals)
for tool in data.tools:
try:
# tool["function"] = convert_to_structured_output(tool["function"])
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")

return data


Expand Down Expand Up @@ -455,11 +466,12 @@ def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):
data.pop("tools")
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")

if "tools" in data:
for tool in data["tools"]:
try:
tool["function"] = convert_to_structured_output(tool["function"])
except ValueError as e:
warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
# # NOTE: move this out to wherever the ChatCompletionRequest is created
# if "tools" in data:
# for tool in data["tools"]:
# try:
# tool["function"] = convert_to_structured_output(tool["function"])
# except ValueError as e:
# warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")

return data
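The practical effect of the openai.py change is that tool schemas are converted to OpenAI's structured-output ("strict") form when the request is built, rather than in prepare_openai_payload, and only when the new use_structured_output flag is set. A rough before/after of the shape such a conversion produces; the real transformation is whatever convert_to_structured_output does, so the exact output below is an assumption based on OpenAI's strict tool-calling format:

# Hypothetical tool function schema before conversion (optional fields allowed).
loose_function_schema = {
    "name": "send_message",
    "description": "Send a message to the user.",
    "parameters": {
        "type": "object",
        "properties": {"message": {"type": "string"}},
        "required": [],
    },
}

# Assumed strict form: 'strict' is set, every property becomes required,
# and additionalProperties is disallowed, per OpenAI structured outputs.
strict_function_schema = {
    "name": "send_message",
    "description": "Send a message to the user.",
    "strict": True,
    "parameters": {
        "type": "object",
        "properties": {"message": {"type": "string"}},
        "required": ["message"],
        "additionalProperties": False,
    },
}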
5 changes: 5 additions & 0 deletions letta/orm/sqlalchemy_base.py
@@ -69,6 +69,7 @@ def list(
join_model: Optional[Base] = None,
join_conditions: Optional[Union[Tuple, List]] = None,
identifier_keys: Optional[List[str]] = None,
identifier_id: Optional[str] = None,
**kwargs,
) -> List["SqlalchemyBase"]:
"""
@@ -147,6 +148,10 @@
if identifier_keys and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))

# given the identifier_id, we can find within the agents table any agents that have the identifier_id in their identity_ids
if identifier_id and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identifier_id)

# Apply filtering logic from kwargs
for key, value in kwargs.items():
if "." in key:
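The new identifier_id argument mirrors the existing identifier_keys filter but matches a single identity row by primary key. A standalone sketch of the same join-and-filter pattern with toy models (letta's real Agent/Identity ORM classes and session handling differ):

# Toy SQLAlchemy models illustrating the join/filter added to list().
from sqlalchemy import Column, ForeignKey, String, Table, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Session, relationship

class Base(DeclarativeBase):
    pass

agent_identities = Table(
    "agent_identities",
    Base.metadata,
    Column("agent_id", ForeignKey("agents.id"), primary_key=True),
    Column("identity_id", ForeignKey("identities.id"), primary_key=True),
)

class Identity(Base):
    __tablename__ = "identities"
    id = Column(String, primary_key=True)
    identifier_key = Column(String)

class Agent(Base):
    __tablename__ = "agents"
    id = Column(String, primary_key=True)
    identities = relationship(Identity, secondary=agent_identities)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    identifier_id = "identity-123"  # hypothetical value
    # Same shape as the new branch: join the identities relationship and
    # filter on the identity's primary key.
    query = select(Agent).join(Agent.identities).filter(Identity.id == identifier_id)
    agents = session.scalars(query).all()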
3 changes: 2 additions & 1 deletion letta/schemas/llm_config.py
@@ -42,6 +42,7 @@ class LLMConfig(BaseModel):
"together", # completions endpoint
"bedrock",
"deepseek",
"xai",
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
@@ -56,7 +57,7 @@
description="The temperature to use when generating text with the model. A higher temperature will result in more random text.",
)
max_tokens: Optional[int] = Field(
1024,
4096,
description="The maximum number of tokens to generate. If not set, the model will use its default value.",
)

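With the default bumped from 1024 to 4096, a config that doesn't set max_tokens explicitly now requests up to 4096 completion tokens. A hand-written example using the new "xai" endpoint type; the model name and context window are the values hardcoded in the xAIProvider below, and any required LLMConfig fields not visible in this hunk would still need to be supplied:

from letta.schemas.llm_config import LLMConfig

config = LLMConfig(
    model="grok-2-1212",
    model_endpoint_type="xai",            # newly allowed endpoint type
    model_endpoint="https://api.x.ai/v1",
    context_window=131072,
)
print(config.max_tokens)  # 4096 with the new default (previously 1024)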
67 changes: 67 additions & 0 deletions letta/schemas/providers.py
@@ -211,6 +211,63 @@ def get_model_context_window_size(self, model_name: str):
return None


class xAIProvider(OpenAIProvider):
"""https://docs.x.ai/docs/api-reference"""

name: str = "xai"
api_key: str = Field(..., description="API key for the xAI/Grok API.")
base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")

def get_model_context_window_size(self, model_name: str) -> Optional[int]:
# xAI doesn't return context window in the model listing,
# so these are hardcoded from their website
if model_name == "grok-2-1212":
return 131072
else:
return None

def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list

response = openai_get_model_list(self.base_url, api_key=self.api_key)

if "data" in response:
data = response["data"]
else:
data = response

configs = []
for model in data:
assert "id" in model, f"xAI/Grok model missing 'id' field: {model}"
model_name = model["id"]

# In case xAI starts supporting it in the future:
if "context_length" in model:
context_window_size = model["context_length"]
else:
context_window_size = self.get_model_context_window_size(model_name)

if not context_window_size:
warnings.warn(f"Couldn't find context window size for model {model_name}")
continue

configs.append(
LLMConfig(
model=model_name,
model_endpoint_type="xai",
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
)
)

return configs

def list_embedding_models(self) -> List[EmbeddingConfig]:
# No embeddings supported
return []


class DeepSeekProvider(OpenAIProvider):
"""
DeepSeek ChatCompletions API is similar to OpenAI's reasoning API,
@@ -456,6 +513,13 @@ def list_llm_models(self) -> List[LLMConfig]:
warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000")
model["context_window"] = 200000

max_tokens = 8192
if "claude-3-opus" in model["id"]:
max_tokens = 4096
if "claude-3-haiku" in model["id"]:
max_tokens = 4096
# TODO: set for 3-7 extended thinking mode

# We set this to false by default, because Anthropic can
# natively support <thinking> tags inside of content fields
# However, putting COT inside of tool calls can make it more
@@ -472,6 +536,7 @@ def list_llm_models(self) -> List[LLMConfig]:
context_window=model["context_window"],
handle=self.get_handle(model["id"]),
put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
max_tokens=max_tokens,
)
)
return configs
@@ -811,6 +876,7 @@ def list_llm_models(self):
model_endpoint=self.base_url,
context_window=self.get_model_context_window(model),
handle=self.get_handle(model),
max_tokens=8192,
)
)
return configs
@@ -862,6 +928,7 @@ def list_llm_models(self) -> List[LLMConfig]:
model_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}",
context_window=context_length,
handle=self.get_handle(model),
max_tokens=8192,
)
)
return configs
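For reference, a hedged sketch of exercising the new xAI provider directly; in practice the server builds providers from its settings, and XAI_API_KEY here is just a placeholder environment variable:

import os

from letta.schemas.providers import xAIProvider

provider = xAIProvider(api_key=os.environ["XAI_API_KEY"])

# Models without a known context window are skipped with a warning, so with
# the current hardcoded table this effectively yields grok-2-1212 (131072 tokens).
for llm_config in provider.list_llm_models():
    print(llm_config.handle, llm_config.context_window, llm_config.max_tokens)

print(provider.list_embedding_models())  # [] -- no embedding models exposed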
32 changes: 6 additions & 26 deletions letta/server/rest_api/chat_completions_interface.py
Expand Up @@ -225,10 +225,10 @@ def _process_chunk_to_openai_style(self, chunk: ChatCompletionChunkResponse) ->
combined_args = "".join(self.current_function_arguments)
parsed_args = OptimisticJSONParser().parse(combined_args)

# TODO: Make this less brittle! This depends on `message` coming first!
# This is a heuristic we use to know if we're done with the `message` part of `send_message`
if len(parsed_args.keys()) > 1:
self._found_message_tool_kwarg = True
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
self.assistant_message_tool_kwarg
) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
self.current_json_parse_result = parsed_args
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
@@ -237,31 +237,11 @@
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(),
finish_reason="stop",
delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
finish_reason=None,
)
],
)
else:
# If the parsed result is different
# This is an edge case we need to consider. E.g. if the last streamed token is '}', we shouldn't stream that out
if parsed_args != self.current_json_parse_result:
self.current_json_parse_result = parsed_args
# If we can see a "message" field, return it as partial content
if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
finish_reason=None,
)
],
)

# If there's a finish reason, pass that along
if choice.finish_reason is not None:
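The rewritten heuristic in the chat-completions interface no longer treats "a second kwarg appeared" as the end of the message; it streams a delta whenever the parsed value of the assistant-message kwarg has changed since the previous chunk, and otherwise emits nothing. A toy walk-through with pre-parsed snapshots standing in for OptimisticJSONParser output (the values are made up):

snapshots = [
    {},                                                  # nothing parsed yet
    {"message": "Hel"},                                  # partial value appears
    {"message": "Hello!"},                               # value grew
    {"message": "Hello!", "request_heartbeat": False},   # message unchanged
]

assistant_message_tool_kwarg = "message"
current_json_parse_result: dict = {}

for parsed_args in snapshots:
    new_value = parsed_args.get(assistant_message_tool_kwarg)
    if new_value and new_value != current_json_parse_result.get(assistant_message_tool_kwarg):
        current_json_parse_result = parsed_args
        print("emit chunk")  # the real code emits a ChatCompletionChunk with the latest fragment
    else:
        print("skip")        # unchanged or empty message -> no chunk emitted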