Commit

Merge pull request #29 from ivan-saorin/feature-max-tokens
Add the ability to pass max_tokens into the configuration (required for some clients such as Anthropic)
KennyVaneetvelde authored Nov 12, 2024
2 parents 615c5df + 7128d51 commit 54c94e0
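
Before the file-by-file diff, a minimal usage sketch (an editor's illustration, not part of the commit) of what the new option enables. It assumes the import paths shown in this changeset, instructor's `from_anthropic` adapter, and that `BaseAgentConfig`'s other fields (memory, system prompt generator) fall back to defaults when omitted:

import instructor
from anthropic import Anthropic
from atomic_agents.agents.base_agent import BaseAgent, BaseAgentConfig

# Anthropic's Messages API rejects requests that omit max_tokens, which is why
# the new config field matters for this client. Anthropic() picks up
# ANTHROPIC_API_KEY from the environment.
client = instructor.from_anthropic(Anthropic())
agent = BaseAgent(
    config=BaseAgentConfig(
        client=client,
        model="claude-3-5-haiku-20241022",
        max_tokens=2048,  # forwarded to every chat.completions.create(...) call below
    )
)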
Showing 8 changed files with 3,804 additions and 3,512 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -145,4 +145,4 @@ personal_scripts/
.output/

# Logs
-debug.log
+debug.log
16 changes: 15 additions & 1 deletion atomic-agents/atomic_agents/agents/base_agent.py
@@ -10,6 +10,9 @@

from instructor.dsl.partial import PartialBase
from jiter import from_json
from rich.console import Console

console = Console()


def model_from_chunks_patched(cls, json_chunks, **kwargs):
@@ -71,6 +74,10 @@ class BaseAgentConfig(BaseModel):
        0,
        description="Temperature for response generation, typically ranging from 0 to 1.",
    )
    max_tokens: Optional[int] = Field(
        None,
        description="Maximum number of tokens allowed in the response generation.",
    )


class BaseAgent:
@@ -88,6 +95,7 @@ class BaseAgent:
        memory (AgentMemory): Memory component for storing chat history.
        system_prompt_generator (SystemPromptGenerator): Component for generating system prompts.
        initial_memory (AgentMemory): Initial state of the memory.
        max_tokens (Optional[int]): Maximum number of tokens allowed in the response.
    """

    input_schema = BaseAgentInputSchema
@@ -109,6 +117,7 @@ def __init__(self, config: BaseAgentConfig):
        self.initial_memory = self.memory.copy()
        self.current_user_input = None
        self.temperature = config.temperature
        self.max_tokens = config.max_tokens

    def reset_memory(self):
        """
@@ -136,12 +145,15 @@ def get_response(self, response_model=None) -> Type[BaseModel]:
"content": self.system_prompt_generator.generate_prompt(),
}
] + self.memory.get_history()

response = self.client.chat.completions.create(
model=self.model,
messages=messages,
model=self.model,
response_model=response_model,
temperature=self.temperature,
max_tokens=self.max_tokens,
)

return response

    def run(self, user_input: Optional[Type[BaseIOSchema]] = None) -> Type[BaseIOSchema]:
@@ -189,6 +201,7 @@ async def get_response_async(self, response_model=None) -> Type[BaseModel]:
            messages=messages,
            response_model=response_model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        return response

@@ -239,6 +252,7 @@ async def stream_response_async(self, user_input: Optional[Type[BaseIOSchema]] =
            messages=messages,
            response_model=self.output_schema,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stream=True,
        )

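With the field wired through `__init__` and all three call sites (`get_response`, `get_response_async`, `stream_response_async`), a configured limit, or the `None` default, now reaches the underlying client on every request. A hypothetical round trip, assuming the agent from the sketch above and that the I/O schemas expose a `chat_message` field:

from atomic_agents.agents.base_agent import BaseAgentInputSchema

# `agent` as configured in the sketch above, with max_tokens=2048
reply = agent.run(BaseAgentInputSchema(chat_message="Summarize Hamlet in two sentences."))
print(reply.chat_message)  # generation is cut off by the provider at 2048 tokens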
3 changes: 3 additions & 0 deletions atomic-agents/tests/agents/test_base_agent.py
@@ -64,6 +64,7 @@ def test_initialization(agent, mock_instructor, mock_memory, mock_system_prompt_
    assert agent.input_schema == BaseAgentInputSchema
    assert agent.output_schema == BaseAgentOutputSchema
    assert agent.temperature == 0
    assert agent.max_tokens is None


def test_reset_memory(agent, mock_memory):
@@ -89,6 +90,7 @@ def test_get_response(agent, mock_instructor, mock_memory, mock_system_prompt_ge
        messages=[{"role": "system", "content": "System prompt"}, {"role": "user", "content": "Hello"}],
        response_model=BaseAgentOutputSchema,
        temperature=0,
        max_tokens=None,
    )


@@ -207,6 +209,7 @@ async def test_get_response_async(agent, mock_instructor, mock_memory, mock_syst
        messages=[{"role": "system", "content": "System prompt"}, {"role": "user", "content": "Hello"}],
        response_model=BaseAgentOutputSchema,
        temperature=0,
        max_tokens=None,
    )


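A companion test one might add (assumed, not part of this diff) to check that a non-default limit is forwarded as well, reusing the fixtures named above and assuming `BaseAgentConfig` accepts them directly:

def test_get_response_forwards_max_tokens(mock_instructor, mock_memory, mock_system_prompt_generator):
    agent = BaseAgent(
        config=BaseAgentConfig(
            client=mock_instructor,
            memory=mock_memory,
            system_prompt_generator=mock_system_prompt_generator,
            model="gpt-4o-mini",
            max_tokens=1024,
        )
    )
    agent.get_response(response_model=BaseAgentOutputSchema)
    # The mocked client records the kwargs it was called with.
    _, kwargs = mock_instructor.chat.completions.create.call_args
    assert kwargs["max_tokens"] == 1024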
3,208 changes: 1,737 additions & 1,471 deletions atomic-examples/quickstart/poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion atomic-examples/quickstart/pyproject.toml
@@ -8,10 +8,11 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
atomic-agents = {path = "../..", develop = true}
-instructor = ">=0.2.1,<1.0.0"
+instructor = ">=1.3.4,<2.0.0"
openai = ">=1.35.12,<2.0.0"
groq = ">=0.11.0,<1.0.0"
mistralai = ">=1.1.0,<2.0.0"
anthropic = ">=0.39.0,<1.0.0"


[build-system]
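The instructor floor moves from the 0.x line to >=1.3.4 presumably because the `from_openai` / `from_anthropic` factory functions used in the quickstart below belong to instructor's 1.x API, and the new `anthropic` dependency supplies the client that `from_anthropic` wraps.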
@@ -22,19 +22,26 @@

# Function to set up the client based on the chosen provider
def setup_client(provider):
-   if provider == "openai":
+   console.log(f"provider: {provider}")
+   if provider == "1" or provider == "openai":
        from openai import OpenAI

        api_key = os.getenv("OPENAI_API_KEY")
        client = instructor.from_openai(OpenAI(api_key=api_key))
        model = "gpt-4o-mini"
-   elif provider == "groq":
+   elif provider == "2" or provider == "anthropic":
+       from anthropic import Anthropic
+
+       api_key = os.getenv("ANTHROPIC_API_KEY")
+       client = instructor.from_anthropic(Anthropic(api_key=api_key))
+       model = "claude-3-5-haiku-20241022"
+   elif provider == "3" or provider == "groq":
        from groq import Groq

        api_key = os.getenv("GROQ_API_KEY")
        client = instructor.from_groq(Groq(api_key=api_key), mode=instructor.Mode.JSON)
        model = "mixtral-8x7b-32768"
-   elif provider == "ollama":
+   elif provider == "4" or provider == "ollama":
        from openai import OpenAI as OllamaClient

        client = instructor.from_openai(OllamaClient(base_url="http://localhost:11434/v1", api_key="ollama"))
@@ -45,20 +52,21 @@ def setup_client(provider):
    return client, model


-# Prompt the user to choose a provider
-provider = console.input("[bold yellow]Choose a provider (openai/groq/ollama): [/bold yellow]").lower()
+# Prompt the user to choose a provider from the list below.
+providers_list = ["openai", "anthropic", "groq", "ollama"]
+y = "bold yellow"
+b = "bold blue"
+g = "bold green"
+provider_inner_str = f"{' / '.join(f'[[{g}]{i+1}[/{g}]]. [{b}]{provider}[/{b}]' for i, provider in enumerate(providers_list))}"
+providers_str = f"[{y}]Choose a provider ({provider_inner_str}): [/{y}]"
+
+provider = console.input(providers_str).lower()

# Set up the client and model based on the chosen provider
client, model = setup_client(provider)

# Agent setup with specified configuration
-agent = BaseAgent(
-    config=BaseAgentConfig(
-        client=client,
-        model=model,
-        memory=memory,
-    )
-)
+agent = BaseAgent(config=BaseAgentConfig(client=client, model=model, memory=memory, max_tokens=2048))

# Generate the default system prompt for the agent
default_system_prompt = agent.system_prompt_generator.generate_prompt()
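Design note on the provider menu: the chained `provider == "1" or provider == "openai"` comparisons in `setup_client` work, but a lookup table keeps the menu numbers and the branching in one place as providers grow. A sketch with a hypothetical helper (not in the diff), using the same provider names:

# Map both menu numbers and provider names to one canonical key.
PROVIDERS = {"1": "openai", "2": "anthropic", "3": "groq", "4": "ollama"}

def normalize_provider(choice: str) -> str:
    choice = choice.strip().lower()
    return PROVIDERS.get(choice, choice)  # "2" -> "anthropic"; "anthropic" passes through

setup_client could then branch on the normalized name alone, and the input prompt could be generated from PROVIDERS so the two never drift apart.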
