diff --git a/agentstack/_tools/__init__.py b/agentstack/_tools/__init__.py index a9382780..daf43bf2 100644 --- a/agentstack/_tools/__init__.py +++ b/agentstack/_tools/__init__.py @@ -13,6 +13,12 @@ TOOLS_CONFIG_FILENAME: str = 'config.json' +class ToolCategory(pydantic.BaseModel): + name: str + title: str # human readable title + description: str + + class ToolConfig(pydantic.BaseModel): """ This represents the configuration data for a tool. @@ -100,6 +106,19 @@ def module(self) -> ModuleType: ) +def get_all_tool_categories() -> list[ToolCategory]: + categories = [] + filename = TOOLS_DIR / 'categories.json' + data = open_json_file(filename) + for name, category in data.items(): + categories.append(ToolCategory(name=name, **category)) + return categories + + +def get_all_tool_category_names() -> list[str]: + return [category.name for category in get_all_tool_categories()] + + def get_all_tool_paths() -> list[Path]: """ Get all the paths to the tool configuration files. @@ -121,3 +140,8 @@ def get_all_tool_names() -> list[str]: def get_all_tools() -> list[ToolConfig]: return [ToolConfig.from_tool_name(path) for path in get_all_tool_names()] + + +def get_tool(name: str) -> ToolConfig: + return ToolConfig.from_tool_name(name) + diff --git a/agentstack/_tools/categories.json b/agentstack/_tools/categories.json new file mode 100644 index 00000000..d0ebebaa --- /dev/null +++ b/agentstack/_tools/categories.json @@ -0,0 +1,46 @@ +{ + "browsing": { + "title": "Browsing", + "description": "Tools that are used to browse the web." + }, + "code-execution": { + "title": "Code Execution", + "description": "Tools that are used to execute code." + }, + "computer-control": { + "title": "Computer Control", + "description": "Tools that are used to control a computer." + }, + "database": { + "title": "Database", + "description": "Tools that are used to interact with databases." + }, + "finance": { + "title": "Finance", + "description": "Tools that are used to interact with financial services." + }, + "image-analysis": { + "title": "Image Analysis", + "description": "Tools that are used to analyze images." + }, + "network-protocols": { + "title": "Network Protocols", + "description": "Tools that are used to interact with network protocols." + }, + "search": { + "title": "Search", + "description": "Tools that are used to search for information." + }, + "storage": { + "title": "Storage", + "description": "Tools that are used to interact with storage." + }, + "unified-apis": { + "title": "Unified APIs", + "description": "Tools that provide a unified API for interacting with multiple services." + }, + "web-retrieval": { + "title": "Web Retrieval", + "description": "Tools that are used to retrieve information from the web." + } +} diff --git a/agentstack/_tools/payman/config.json b/agentstack/_tools/payman/config.json index 10eee8d9..9e31ca54 100644 --- a/agentstack/_tools/payman/config.json +++ b/agentstack/_tools/payman/config.json @@ -1,6 +1,6 @@ { "name": "payman", - "category": "financial-infra", + "category": "finance", "tools": [ "send_payment", "search_available_payees", diff --git a/agentstack/_tools/stripe/config.json b/agentstack/_tools/stripe/config.json index 89b18366..8cdfe3da 100644 --- a/agentstack/_tools/stripe/config.json +++ b/agentstack/_tools/stripe/config.json @@ -1,7 +1,7 @@ { "name": "stripe", "url": "https://github.com/stripe/agent-toolkit", - "category": "application-specific", + "category": "finance", "env": { "STRIPE_SECRET_KEY": null }, diff --git a/agentstack/_tools/weaviate/config.json b/agentstack/_tools/weaviate/config.json index 1323a10f..7e7bb5ce 100644 --- a/agentstack/_tools/weaviate/config.json +++ b/agentstack/_tools/weaviate/config.json @@ -1,7 +1,7 @@ { "name": "weaviate", "url": "https://github.com/weaviate/weaviate-python-client", - "category": "vector-store", + "category": "database", "env": { "WEAVIATE_URL": null, "WEAVIATE_API_KEY": null, diff --git a/agentstack/cli/__init__.py b/agentstack/cli/__init__.py index fba3c0c2..457606c7 100644 --- a/agentstack/cli/__init__.py +++ b/agentstack/cli/__init__.py @@ -1,9 +1,7 @@ -from .cli import configure_default_model, welcome_message, get_validated_input, parse_insertion_point +from .cli import LOGO, configure_default_model, welcome_message, get_validated_input, parse_insertion_point from .init import init_project -from .wizard import run_wizard from .run import run_project from .tools import list_tools, add_tool, remove_tool from .tasks import add_task from .agents import add_agent from .templates import insert_template, export_template - diff --git a/agentstack/cli/cli.py b/agentstack/cli/cli.py index a40e83ff..68a9620c 100644 --- a/agentstack/cli/cli.py +++ b/agentstack/cli/cli.py @@ -1,34 +1,31 @@ from typing import Optional import os, sys -from art import text2art import inquirer from agentstack import conf, log from agentstack.conf import ConfigFile from agentstack.exceptions import ValidationError from agentstack.utils import validator_not_empty, is_snake_case +from agentstack import providers from agentstack.generation import InsertionPoint -PREFERRED_MODELS = [ - 'groq/deepseek-r1-distill-llama-70b', - 'deepseek/deepseek-chat', - 'deepseek/deepseek-coder', - 'deepseek/deepseek-reasoner', - 'openai/gpt-4o', - 'anthropic/claude-3-5-sonnet', - 'openai/o1-preview', - 'openai/gpt-4-turbo', - 'anthropic/claude-3-opus', -] +LOGO = """\ + ___ ___ ___ ___ ___ ___ ___ ___ ___ ___ + /\ \ /\ \ /\ \ /\__\ /\ \ /\ \ /\ \ /\ \ /\ \ /\__\ + /::\ \ /::\ \ /::\ \ /:| _|_ \:\ \ /::\ \ \:\ \ /::\ \ /::\ \ /:/ _/_ + /::\:\__\ /:/\:\__\ /::\:\__\ /::|/\__\ /::\__\ /\:\:\__\ /::\__\ /::\:\__\ /:/\:\__\ /::-"\__\\ + \/\::/ / \:\:\/__/ \:\:\/ / \/|::/ / /:/\/__/ \:\:\/__/ /:/\/__/ \/\::/ / \:\ \/__/ \;:;-",-" + /:/ / \::/ / \:\/ / |:/ / \/__/ \::/ / \/__/ /:/ / \:\__\ |:| | + \/__/ \/__/ \/__/ \/__/ \/__/ \/__/ \/__/ \|__| +""" def welcome_message(): - title = text2art("AgentStack", font="smisome1") tagline = "The easiest way to build a robust agent application!" border = "-" * len(tagline) # Print the welcome message with ASCII art - log.info(title) + log.info(LOGO) log.info(border) log.info(tagline) log.info(border) @@ -45,7 +42,7 @@ def configure_default_model(): other_msg = "Other (enter a model name)" model = inquirer.list_input( message="Which model would you like to use?", - choices=PREFERRED_MODELS + [other_msg], + choices=providers.get_preferred_model_ids() + [other_msg], ) if model == other_msg: # If the user selects "Other", prompt for a model name diff --git a/agentstack/cli/init.py b/agentstack/cli/init.py index 758db9b8..83a8a4ef 100644 --- a/agentstack/cli/init.py +++ b/agentstack/cli/init.py @@ -14,7 +14,6 @@ from agentstack.templates import get_all_templates, TemplateConfig from agentstack.cli import welcome_message -from agentstack.cli.wizard import run_wizard from agentstack.cli.templates import insert_template @@ -67,7 +66,7 @@ def init_project( slug_name: Optional[str] = None, template: Optional[str] = None, framework: Optional[str] = None, - use_wizard: bool = False, + template_data: Optional[TemplateConfig] = None, ): """ Initialize a new project in the current directory. @@ -78,9 +77,6 @@ def init_project( - insert Tasks, Agents and Tools """ # TODO prevent the user from passing the --path argument to init - if template and use_wizard: - raise Exception("Template and wizard flags cannot be used together") - require_uv() welcome_message() @@ -102,16 +98,14 @@ def init_project( if os.path.exists(conf.PATH): # cookiecutter requires the directory to not exist raise Exception(f"Directory already exists: {conf.PATH}") - if use_wizard: - log.debug("Initializing new project with wizard.") - template_data = run_wizard(slug_name) - elif template: + if not template_data and template: log.debug(f"Initializing new project with template: {template}") template_data = TemplateConfig.from_user_input(template) - else: + elif not template_data: log.debug("Initializing new project with template selection.") template_data = select_template(slug_name, framework) + assert template_data # appease type checker log.notify("🦾 Creating a new AgentStack project...") log.info(f"Using project directory: {conf.PATH.absolute()}") diff --git a/agentstack/cli/wizard.py b/agentstack/cli/wizard.py index 4ad24802..d114f4a3 100644 --- a/agentstack/cli/wizard.py +++ b/agentstack/cli/wizard.py @@ -1,261 +1,1096 @@ -from typing import Optional -import os +import os, sys +import curses import time -import inquirer -import webbrowser -from art import text2art -from agentstack import log -from agentstack.frameworks import SUPPORTED_FRAMEWORKS -from agentstack.utils import open_json_file, is_snake_case -from agentstack.cli import welcome_message, get_validated_input -from agentstack.cli.cli import PREFERRED_MODELS -from agentstack._tools import get_all_tools, get_all_tool_names +import math +from random import randint +from dataclasses import dataclass +from typing import Optional, Any, Union, TypedDict +from enum import Enum +from pathlib import Path +from abc import abstractmethod, ABCMeta + +from agentstack import conf, log +from agentstack.utils import is_snake_case +from agentstack.tui import * +from agentstack import providers +from agentstack import frameworks +from agentstack._tools import ( + get_all_tools, + get_tool, + get_all_tool_categories, + get_all_tool_category_names, +) from agentstack.templates import TemplateConfig +from agentstack.cli import LOGO, init_project -class WizardData(dict): - def to_template_config(self) -> TemplateConfig: - agents = [] - for agent in self['design']['agents']: - agents.append(TemplateConfig.Agent(**{ - 'name': agent['name'], - 'role': agent['role'], - 'goal': agent['goal'], - 'backstory': agent['backstory'], - 'llm': agent['model'], - })) - - tasks = [] - for task in self['design']['tasks']: - tasks.append(TemplateConfig.Task(**{ - 'name': task['name'], - 'description': task['description'], - 'expected_output': task['expected_output'], - 'agent': task['agent'], - })) - - tools = [] - for tool in self['tools']: - tools.append(TemplateConfig.Tool(**{ - 'name': tool, - 'agents': [agent.name for agent in agents], # all agents - })) - - return TemplateConfig( - name=self['project']['name'], - description=self['project']['description'], - template_version=4, - framework=self['framework'], - method='sequential', - manager_agent=None, - agents=agents, - tasks=tasks, - tools=tools, - graph=[], - inputs={}, +COLOR_BORDER = Color(90) +COLOR_MAIN = Color(220) +COLOR_TITLE = Color(220, 100, 40, reversed=True) +COLOR_ERROR = Color(0, 70) +COLOR_BANNER = Color(80, 80, 80) +COLOR_FORM = Color(300) +COLOR_FORM_BORDER = Color(300, 80) +COLOR_BUTTON = Color(300, 100, 80, reversed=True) +COLOR_FIELD_BG = Color(300, 50, 50, reversed=True) +COLOR_FIELD_BORDER = Color(300, 100, 50) +COLOR_FIELD_ACTIVE = Color(300, 80) + + +class FieldColors(TypedDict): + color: Color + border: Color + active: Color + + +FIELD_COLORS: FieldColors = { + 'color': COLOR_FIELD_BG, + 'border': COLOR_FIELD_BORDER, + 'active': COLOR_FIELD_ACTIVE, +} + + +class LogoElement(Text): + h_align = ALIGN_CENTER + + def __init__(self, coords: tuple[int, int], dims: tuple[int, int]): + super().__init__(coords, dims) + self.color = COLOR_MAIN + self.value = LOGO + self.stars = [(3, 1), (25, 5), (34, 1), (52, 2), (79, 3), (97, 1)] + self._star_colors = {} + content_width = len(LOGO.split('\n')[0]) + self.left_offset = max(0, round((self.width - content_width) / 2)) + + def _get_star_color(self, index: int) -> Color: + if index not in self._star_colors: + self._star_colors[index] = ColorAnimation( + Color(randint(0, 150)), + Color(randint(200, 360)), + duration=2.0, + loop=True, + ) + return self._star_colors[index] + + def render(self) -> None: + super().render() + for i, (x, y) in enumerate(self.stars): + try: + self.grid.addch(y, self.left_offset + x, '*', self._get_star_color(i).to_curses()) + except curses.error: + pass # overflow + + +class StarBox(Box): + """Renders random stars that animate down the page in the background of the box.""" + + def __init__(self, coords: tuple[int, int], dims: tuple[int, int], **kwargs): + super().__init__(coords, dims, **kwargs) + self.stars = [(randint(0, self.width - 1), randint(0, self.height - 1)) for _ in range(11)] + self.star_colors = [ + ColorAnimation( + Color(randint(0, 150)), + Color(randint(200, 360)), + duration=2.0, + loop=False, + ) + for _ in range(11) + ] + self.star_y = [randint(0, self.height - 1) for _ in range(11)] + self.star_x = [randint(0, self.width - 1) for _ in range(11)] + self.star_speed = 0.001 + self.star_timer = 0.0 + self.star_index = 0 + + def render(self) -> None: + self.grid.clear() + for i in range(len(self.stars)): + if self.star_y[i] > 0: # undraw previous star position + self.grid.addch(self.star_y[i] - 1, self.star_x[i], ' ') + else: # previous star was at bottom of screen + self.grid.addch(self.height - 1, self.star_x[i], ' ') + + if self.star_y[i] < self.height: + self.grid.addch(self.star_y[i], self.star_x[i], '*', self.star_colors[i].to_curses()) + self.star_y[i] += 1 + else: + self.star_y[i] = 0 + self.star_x[i] = randint(0, self.width - 1) + super().render() + + +class HelpText(Text): + def __init__(self, coords: tuple[int, int], dims: tuple[int, int]) -> None: + super().__init__(coords, dims) + self.color = Color(0, 0, 70) + self.value = " | ".join( + [ + "select [tab]", + "navigate [up / down]", + "confirm [space / enter]", + "[q]uit", + ] + ) + if conf.DEBUG: + self.value += " | [d]ebug" + + +class WizardView(View): + app: 'WizardApp' + + +class BannerView(WizardView): + name = "banner" + title = "Welcome to AgentStack" + sparkle = "The easiest way to build a robust agent application." + subtitle = "Let's get started!" + + def _get_color(self) -> Color: + return ColorAnimation( + start=COLOR_BANNER.sat(0).val(0), + end=COLOR_BANNER, + duration=0.5, ) + def layout(self) -> list[Renderable]: + buttons_conf: dict[str, Callable] = {} + + if not self.app.state.project: + buttons_conf["Create Project"] = lambda: self.app.load('project', workflow='project') + else: + buttons_conf["New Agent"] = lambda: self.app.load('agent', workflow='agent') -def run_wizard(slug_name: str) -> TemplateConfig: - project_details = ask_project_details(slug_name) - welcome_message() - framework = ask_framework() - design = ask_design() - tools = ask_tools() - - wizard_data = WizardData({ - 'project': project_details, - 'framework': framework, - 'design': design, - 'tools': tools, - }) - return wizard_data.to_template_config() + if len(self.app.state.agents): + buttons_conf["New Task"] = lambda: self.app.load('task', workflow='task') + buttons_conf["Add Tools"] = lambda: self.app.load('tool_agent_selection', workflow='tool') + if self.app.state.project: + buttons_conf["Finish"] = lambda: self.app.finish() -def ask_framework() -> str: - framework = inquirer.list_input( - message="What agent framework do you want to use?", - choices=SUPPORTED_FRAMEWORKS, - ) - # - # if framework == "Learn what these are (link)": - # webbrowser.open("https://youtu.be/xvFZjo5PgG0") - # framework = inquirer.list_input( - # message="What agent framework do you want to use?", - # choices=["CrewAI", "Autogen", "LiteLLM"], - # ) - # - # while framework in ['Autogen', 'LiteLLM']: - # print(f"{framework} support coming soon!!") - # framework = inquirer.list_input( - # message="What agent framework do you want to use?", - # choices=["CrewAI", "Autogen", "LiteLLM"], - # ) - - #log.success("Congrats! Your project is ready to go! Quickly add features now or skip to do it later.\n\n") - return framework - - -def ask_agent_details(): - agent = {} - - agent['name'] = get_validated_input( - "What's the name of this agent? (snake_case)", min_length=3, snake_case=True - ) - - agent['role'] = get_validated_input("What role does this agent have?", min_length=3) + buttons: list[Button] = [] + num_buttons = len(buttons_conf) + button_width = min(round(self.width / 2), round(self.width / num_buttons) - 2) + left_offset = round((self.width - (num_buttons * button_width)) / 2) if num_buttons == 1 else 2 - agent['goal'] = get_validated_input("What is the goal of the agent?", min_length=10) - - agent['backstory'] = get_validated_input("Give your agent a backstory", min_length=10) + for title, action in buttons_conf.items(): + buttons.append( + Button( + (self.height - 5, left_offset), + (3, button_width), + title, + color=COLOR_BUTTON, + on_confirm=action, + ) + ) + left_offset += button_width + 1 - agent['model'] = inquirer.list_input( - message="What LLM should this agent use?", choices=PREFERRED_MODELS, default=PREFERRED_MODELS[0] - ) - - return agent - - -def ask_task_details(agents: list[dict]) -> dict: - task = {} - - task['name'] = get_validated_input( - "What's the name of this task? (snake_case)", min_length=3, snake_case=True - ) - - task['description'] = get_validated_input("Describe the task in more detail", min_length=10) - - task['expected_output'] = get_validated_input( - "What do you expect the result to look like? (ex: A 5 bullet point summary of the email)", - min_length=10, - ) - - task['agent'] = inquirer.list_input( - message="Which agent should be assigned this task?", - choices=[a['name'] for a in agents], - ) - - return task - - -def ask_design() -> dict: - use_wizard = inquirer.confirm( - message="Would you like to use the CLI wizard to set up agents and tasks?", - ) - - if not use_wizard: - return {'agents': [], 'tasks': []} - - os.system("cls" if os.name == "nt" else "clear") - - title = text2art("AgentWizard", font="shimrod") - - print(title) - - print(""" -🪄 welcome to the agent builder wizard!! 🪄 - -First we need to create the agents that will work together to accomplish tasks: - """) - make_agent = True - agents = [] - while make_agent: - print('---') - print(f"Agent #{len(agents)+1}") - agent = None - agent = ask_agent_details() - agents.append(agent) - make_agent = inquirer.confirm(message="Create another agent?") - - print('') - for x in range(3): - time.sleep(0.3) - print('.') - print('Boom! We made some agents (ノ>ω<)ノ :。・:*:・゚’★,。・:*:・゚’☆') - time.sleep(0.5) - print('') - print('Now lets make some tasks for the agents to accomplish!') - print('') - - make_task = True - tasks = [] - while make_task: - print('---') - print(f"Task #{len(tasks) + 1}") - task = ask_task_details(agents) - tasks.append(task) - make_task = inquirer.confirm(message="Create another task?") + return [ + StarBox( + (0, 0), + (self.height - 1, self.width), + color=COLOR_BORDER, + modules=[ + LogoElement((1, 1), (7, self.width - 2)), + Box( + (round(self.height / 3), round(self.width / 4)), + (9, round(self.width / 2)), + color=COLOR_BANNER, + modules=[ + BoldText( + (1, 2), + (2, round(self.width / 2) - 3), + color=self._get_color(), + value=self.title, + ), + WrappedText( + (3, 2), + (3, round(self.width / 2) - 3), + color=self._get_color(), + value=self.sparkle, + ), + WrappedText( + (6, 2), + (2, round(self.width / 2) - 3), + color=self._get_color(), + value=self.subtitle, + ), + ], + ), + *buttons, + ], + ), + HelpText((self.height - 1, 0), (1, self.width)), + ] + + +class FormView(WizardView, metaclass=ABCMeta): + title: str + error_message: Node + + def __init__(self, app: 'App'): + super().__init__(app) + self.error_message = Node() + + def submit(self): + pass - print('') - for x in range(3): - time.sleep(0.3) - print('.') - print('Let there be tasks (ノ ˘_˘)ノ ζ|||ζ ζ|||ζ ζ|||ζ') + def error(self, message: str): + self.error_message.value = message - return {'tasks': tasks, 'agents': agents} + @abstractmethod + def form(self) -> list[Renderable]: ... + def layout(self) -> list[Renderable]: + return [ + Box( + (0, 0), + (self.height - 1, self.width), + color=COLOR_BORDER, + modules=[ + LogoElement((1, 1), (7, self.width - 2)), + Title((9, 1), (1, self.width - 2), color=COLOR_TITLE, value=self.title), + Title( + (self.height - 5, round(self.width / 3)), + (3, round(self.width / 3)), + color=COLOR_ERROR, + value=self.error_message, + ), + *self.form(), + Button( + (self.height - 5, self.width - 17), + (3, 15), + "Next", + color=COLOR_BUTTON, + on_confirm=self.submit, + ), + ], + ), + HelpText((self.height - 1, 0), (1, self.width)), + ] + + +class AgentSelectionView(FormView, metaclass=ABCMeta): + title = "Select an Agent" -def ask_tools() -> list: - use_tools = inquirer.confirm( - message="Do you want to add agent tools now? (you can do this later with `agentstack tools add `)", - ) + def __init__(self, app: 'App'): + super().__init__(app) + self.agent_key = Node() + self.agent_name = Node() + self.agent_llm = Node() + self.agent_description = Node() - if not use_tools: - return [] + def set_agent_selection(self, index: int, value: str): + agent_data = self.app.state.agents[value] + self.agent_name.value = value + self.agent_llm.value = agent_data['llm'] + self.agent_description.value = agent_data['role'] - tools_to_add = [] + def set_agent_choice(self, index: int, value: str): + self.agent_key.value = value - adding_tools = True - tool_configs = get_all_tools() + def get_agent_options(self) -> list[str]: + return list(self.app.state.agents.keys()) - while adding_tools: - tool_categories = [] - for tool_config in tool_configs: - if tool_config.category not in tool_categories: - tool_categories.append(tool_config.category) - - tool_type = inquirer.list_input( - message="What category tool do you want to add?", - choices=tool_categories + ["~~ Stop adding tools ~~"], + @abstractmethod + def submit(self): ... + + def form(self) -> list[Renderable]: + return [ + RadioSelect( + (10, 1), + (self.height - 15, round(self.width / 2) - 2), + options=self.get_agent_options(), + color=COLOR_FORM_BORDER, + highlight=ColorAnimation(COLOR_BUTTON.sat(0), COLOR_BUTTON, duration=0.2), + on_change=self.set_agent_selection, + on_select=self.set_agent_choice, + ), + Box( + (10, round(self.width / 2) + 1), + (self.height - 15, round(self.width / 2) - 2), + color=COLOR_FORM_BORDER, + modules=[ + ASCIIText( + (1, 3), + (4, round(self.width / 2) - 10), + color=COLOR_FORM.sat(40), + value=self.agent_name, + ), + BoldText((5, 3), (1, round(self.width / 2) - 10), color=COLOR_FORM, value=self.agent_llm), + WrappedText( + (7, 3), + (min(5, self.height - 24), round(self.width / 2) - 10), + color=COLOR_FORM.sat(50), + value=self.agent_description, + ), + ], + ), + ] + + +class ProjectView(FormView): + title = "Define your Project" + + def __init__(self, app: 'App'): + super().__init__(app) + self.project_name = Node() + self.project_description = Node() + + def submit(self): + if not self.project_name.value: + self.error("Name is required.") + return + + if not is_snake_case(self.project_name.value): + self.error("Name must be in snake_case.") + return + + if os.path.exists(conf.PATH / self.project_name.value): + self.error(f"Directory '{self.project_name.value}' already exists.") + return + + self.app.state.create_project( + name=self.project_name.value, + description=self.project_description.value, ) + self.app.advance() + + def form(self) -> list[Renderable]: + return [ + Text((11, 2), (1, 12), color=COLOR_FORM, value="Name"), + TextInput( + (11, 14), + (2, self.width - 15), + self.project_name, + placeholder="This will be used to create a new directory. Must be snake_case.", + **FIELD_COLORS, + ), + Text((13, 2), (1, 12), color=COLOR_FORM, value="Description"), + TextInput( + (13, 14), + (5, self.width - 15), + self.project_description, + placeholder="Describe what you project will do.", + **FIELD_COLORS, + ), + ] + + +class FrameworkView(FormView): + title = "Select a Framework" + + def __init__(self, app: 'App'): + super().__init__(app) + self.framework_key = Node() + self.framework_name = Node() + self.framework_description = Node() + self.framework_options = { + key: frameworks.get_framework_info(key) for key in frameworks.SUPPORTED_FRAMEWORKS + } + + def set_framework_selection(self, index: int, value: str): + """Update the content of the framework info box.""" + data = self.framework_options[value] + self.framework_name.value = data['name'] + self.framework_description.value = data['description'] + + def set_framework_choice(self, index: int, value: str): + """Save the selection.""" + self.framework_key.value = value + + def submit(self): + if not self.framework_key.value: + self.error("Framework is required.") + return + + self.app.state.update_framework(self.framework_key.value) + self.app.advance() + + def form(self) -> list[Renderable]: + return [ + RadioSelect( + (10, 1), + (self.height - 15, round(self.width / 2) - 2), + options=list(self.framework_options.keys()), + color=COLOR_FORM_BORDER, + highlight=ColorAnimation(COLOR_BUTTON.sat(0), COLOR_BUTTON, duration=0.2), + on_change=self.set_framework_selection, + on_select=self.set_framework_choice, + ), + Box( + (10, round(self.width / 2)), + (self.height - 15, round(self.width / 2) - 2), + color=COLOR_FORM_BORDER, + modules=[ + ASCIIText( + (1, 3), + (4, round(self.width / 2) - 10), + color=COLOR_FORM.sat(40), + value=self.framework_name, + ), + BoldText( + (5, 3), (1, round(self.width / 2) - 10), color=COLOR_FORM, value=self.framework_name + ), + WrappedText( + (7, 3), + (min(5, self.height - 24), round(self.width / 2) - 10), + color=COLOR_FORM.sat(50), + value=self.framework_description, + ), + ], + ), + ] + + +class AfterProjectView(BannerView): + title = "We've got a project!" + sparkle = "(づ ◕‿◕ )づ *゚・:*:・゚’★,。・:*:・゚’☆" + subtitle = "Now, add an Agent to handle your tasks!" + + +class AgentView(FormView): + title = "Define your Agent" + + def __init__(self, app: 'App'): + super().__init__(app) + self.agent_name = Node() + self.agent_role = Node() + self.agent_goal = Node() + self.agent_backstory = Node() + + def submit(self): + agent_name = self.agent_name.value + if not agent_name: + self.error("Name is required.") + return + + if not is_snake_case(agent_name): + self.error("Name must be in snake_case.") + return + + if agent_name in self.app.state.agents.keys(): + self.error("Agent name must be unique.") + return + + if agent_name in self.app.state.tasks.keys(): + self.error("Agent name cannot match a task name.") + return + + self.app.state.create_agent( + name=agent_name, + role=self.agent_role.value, + goal=self.agent_goal.value, + backstory=self.agent_backstory.value, + ) + self.app.advance() + + def form(self) -> list[Renderable]: + large_field_height = min(5, round((self.height - 17) / 3)) + return [ + Text((11, 2), (1, 12), color=COLOR_FORM, value="Name"), + TextInput( + (11, 14), + (2, self.width - 16), + self.agent_name, + placeholder="A unique name for this agent. Must be snake_case.", + **FIELD_COLORS, + ), + Text((13, 2), (1, 12), color=COLOR_FORM, value="Role"), + TextInput( + (13, 14), + (large_field_height, self.width - 16), + self.agent_role, + placeholder="A prompt to the agent that describes the role it takes in your project.", + ** FIELD_COLORS, + ), + Text((13 + large_field_height, 2), (1, 12), color=COLOR_FORM, value="Goal"), + TextInput( + (13 + large_field_height, 14), + (large_field_height, self.width - 16), + self.agent_goal, + placeholder="A prompt to the agent that describes the goal it is trying to achieve.", + **FIELD_COLORS, + ), + Text((13 + (large_field_height * 2), 2), (1, 12), color=COLOR_FORM, value="Backstory"), + TextInput( + (13 + (large_field_height * 2), 14), + (large_field_height, self.width - 16), + self.agent_backstory, + placeholder="A prompt to the agent that describes the backstory of it's purpose.", + **FIELD_COLORS, + ), + ] + + +class ModelView(FormView): + title = "Select a Model" + + def __init__(self, app: 'App'): + super().__init__(app) + self.MODEL_CHOICES = providers.get_preferred_models() + self.model_choice = Node() + self.model_logo = Node() + self.model_name = Node() + self.model_description = Node() + + def set_model_selection(self, index: int, value: str): + """Update the content of the model info box.""" + model = self.MODEL_CHOICES[index] + self.model_logo.value = model.host + self.model_name.value = model.name + self.model_description.value = model.description + + def set_model_choice(self, index: int, value: str): + """Save the selection.""" + # list in UI shows the actual key + self.model_choice.value = value + + def get_model_options(self): + return providers.get_preferred_model_ids() + + def submit(self): + if not self.model_choice.value: + self.error("Model is required.") + return + + self.app.state.update_active_agent(llm=self.model_choice.value) + self.app.advance() - tools_in_cat = [] - for tool_config in tool_configs: - if tool_config.category == tool_type: - tools_in_cat.append(tool_config) - - tool_selection = inquirer.list_input( - message="Select your tool", - choices=[f"{t.name} - {t.url}" for t in tools_in_cat if t not in tools_to_add], + def form(self) -> list[Renderable]: + return [ + RadioSelect( + (10, 1), + (self.height - 15, round(self.width / 2) - 2), + options=self.get_model_options(), + color=COLOR_FORM_BORDER, + highlight=ColorAnimation(COLOR_BUTTON.sat(0), COLOR_BUTTON, duration=0.2), + on_change=self.set_model_selection, + on_select=self.set_model_choice, + ), + Box( + (10, round(self.width / 2)), + (self.height - 15, round(self.width / 2) - 2), + color=COLOR_FORM_BORDER, + modules=[ + ASCIIText( + (1, 3), + (4, round(self.width / 2) - 10), + color=COLOR_FORM.sat(40), + value=self.model_logo, + ), + BoldText( + (5, 3), (1, round(self.width / 2) - 10), color=COLOR_FORM, value=self.model_name + ), + WrappedText( + (7, 3), + (min(5, self.height - 24), round(self.width / 2) - 10), + color=COLOR_FORM.sat(50), + value=self.model_description, + ), + ], + ), + ] + + +class ToolCategoryView(FormView): + title = "Select a Tool Category" + + def __init__(self, app: 'App'): + super().__init__(app) + self.tool_category_key = Node() + self.tool_category_name = Node() + self.tool_category_description = Node() + + def set_tool_category_selection(self, index: int, value: str): + tool_category = None + for _tool_category in get_all_tool_categories(): + if _tool_category.name == value: # search by name + tool_category = _tool_category + break + + if tool_category: + self.tool_category_name.value = tool_category.title + self.tool_category_description.value = tool_category.description + + def set_tool_category_choice(self, index: int, value: str): + self.tool_category_key.value = value + + def submit(self): + if not self.tool_category_key.value: + self.error("Tool category is required.") + return + + self.app.state.tool_category = self.tool_category_key.value + self.app.advance() + + def skip(self): + self.app.advance(steps=2) + + def form(self) -> list[Renderable]: + return [ + RadioSelect( + (10, 1), + (self.height - 15, round(self.width / 2) - 2), + options=get_all_tool_category_names(), + color=COLOR_FORM_BORDER, + highlight=ColorAnimation(COLOR_BUTTON.sat(0), COLOR_BUTTON, duration=0.2), + on_change=self.set_tool_category_selection, + on_select=self.set_tool_category_choice, + ), + Box( + (10, round(self.width / 2) + 1), + (self.height - 15, round(self.width / 2) - 2), + color=COLOR_FORM_BORDER, + modules=[ + ASCIIText( + (1, 3), + (4, round(self.width / 2) - 10), + color=COLOR_FORM.sat(40), + value=self.tool_category_name, + ), + BoldText( + (5, 3), + (1, round(self.width / 2) - 10), + color=COLOR_FORM, + value=self.tool_category_name, + ), + WrappedText( + (7, 3), + (min(5, self.height - 24), round(self.width / 2) - 10), + color=COLOR_FORM.sat(50), + value=self.tool_category_description, + ), + ], + ), + Button( + (self.height - 5, 2), + (3, 15), + "Skip", + color=COLOR_BUTTON, + on_confirm=self.skip, + ), + ] + + +class ToolView(FormView): + title = "Select a Tool" + + def __init__(self, app: 'App'): + super().__init__(app) + self.tool_key = Node() + self.tool_name = Node() + self.tool_description = Node() + + @property + def category(self) -> str: + return self.app.state.tool_category + + def set_tool_selection(self, index: int, value: str): + tool_config = get_tool(value) + self.tool_name.value = tool_config.name + self.tool_description.value = tool_config.cta + + def set_tool_choice(self, index: int, value: str): + self.tool_key.value = value + + def get_tool_options(self) -> list[str]: + return sorted([tool.name for tool in get_all_tools() if tool.category == self.category]) + + def submit(self): + if not self.tool_key.value: + self.error("Tool is required.") + return + + self.app.state.update_active_agent_tools(self.tool_key.value) + self.app.advance() + + def back(self): + self.app.back() + + def form(self) -> list[Renderable]: + return [ + RadioSelect( + (10, 1), + (self.height - 15, round(self.width / 2) - 2), + options=self.get_tool_options(), + color=COLOR_FORM_BORDER, + highlight=ColorAnimation(COLOR_BUTTON.sat(0), COLOR_BUTTON, duration=0.2), + on_change=self.set_tool_selection, + on_select=self.set_tool_choice, + ), + Box( + (10, round(self.width / 2) + 1), + (self.height - 15, round(self.width / 2) - 2), + color=COLOR_FORM_BORDER, + modules=[ + ASCIIText( + (1, 3), + (4, round(self.width / 2) - 10), + color=COLOR_FORM.sat(40), + value=self.tool_name, + ), + BoldText((5, 3), (1, round(self.width / 2) - 10), color=COLOR_FORM, value=self.tool_name), + WrappedText( + (7, 3), + (min(5, self.height - 24), round(self.width / 2) - 10), + color=COLOR_FORM.sat(50), + value=self.tool_description, + ), + ], + ), + Button( + (self.height - 5, 2), + (3, 15), + "Back", + color=COLOR_BUTTON, + on_confirm=self.back, + ), + ] + + +class ToolAgentSelectionView(AgentSelectionView): + title = "Select an Agent for your Tool" + + def submit(self): + if not self.agent_key.value: + self.error("Agent is required.") + return + + self.app.state.active_agent = self.agent_key.value + self.app.advance() + + +class AfterAgentView(BannerView): + title = "Boom! We made some agents." + sparkle = "(ノ>ω<)ノ :。・:*:・゚’★,。・:*:・゚’☆" + subtitle = "Now lets make some tasks for the agents to accomplish!" + + +class TaskView(FormView): + title = "Define your Task" + + def __init__(self, app: 'App'): + super().__init__(app) + self.task_name = Node() + self.task_description = Node() + self.expected_output = Node() + + def submit(self): + task_name = self.task_name.value + if not self.task_name: + self.error("Task name is required.") + return + + if not is_snake_case(task_name): + self.error("Task name must be in snake_case.") + return + + if task_name in self.app.state.tasks.keys(): + self.error("Task name must be unique.") + return + + if task_name in self.app.state.agents.keys(): + self.error("Task name cannot match an agent name.") + return + + self.app.state.create_task( + name=task_name, + description=self.task_description.value, + expected_output=self.expected_output.value, ) + self.app.advance() + + def form(self) -> list[Renderable]: + large_field_height = min(5, round((self.height - 17) / 3)) + return [ + Text((11, 2), (1, 12), color=COLOR_FORM, value="Name"), + TextInput( + (11, 14), + (2, self.width - 16), + self.task_name, + placeholder="A unique name for this task. Must be snake_case.", + **FIELD_COLORS, + ), + Text((13, 2), (1, 12), color=COLOR_FORM, value="Description"), + TextInput( + (13, 14), + (large_field_height, self.width - 16), + self.task_description, + placeholder="A prompt for this task that describes what should be done.", + **FIELD_COLORS, + ), + Text((13 + large_field_height, 2), (2, 12), color=COLOR_FORM, value="Expected\nOutput"), + TextInput( + (13 + large_field_height, 14), + (large_field_height, self.width - 16), + self.expected_output, + placeholder="A prompt for this task that describes what the output should look like.", + **FIELD_COLORS, + ), + ] - tools_to_add.append(tool_selection.split(' - ')[0]) - log.info("Adding tools:") - for t in tools_to_add: - log.info(f' - {t}') - log.info('') - adding_tools = inquirer.confirm("Add another tool?") +class TaskAgentSelectionView(AgentSelectionView): + title = "Select an Agent for your Task" - return tools_to_add + def submit(self): + if not self.agent_key.value: + self.error("Agent is required.") + return + self.app.state.update_active_task(agent=self.agent_key.value) + self.app.advance() -def ask_project_details(slug_name: Optional[str] = None) -> dict: - name = inquirer.text(message="What's the name of your project (snake_case)", default=slug_name or '') - if not is_snake_case(name): - log.error("Project name must be snake case") - return ask_project_details(slug_name) +class AfterTaskView(BannerView): + title = "Let there be tasks!" + sparkle = "(ノ ˘_˘)ノ ζ|||ζ ζ|||ζ ζ|||ζ" + subtitle = "Tasks are the heart of your agent's work. " - questions = inquirer.prompt( - [ - inquirer.Text("version", message="What's the initial version", default="0.1.0"), - inquirer.Text("description", message="Enter a description for your project"), - inquirer.Text("author", message="Who's the author (your name)?"), + +class DebugView(WizardView): + name = "debug" + + def layout(self) -> list[Renderable]: + from agentstack.utils import get_version + + return [ + Box( + (0, 0), + (self.height - 1, self.width), + color=COLOR_BORDER, + modules=[ + ColorWheel((1, 1)), + Title( + (self.height - 6, 3), + (1, self.width - 5), + color=COLOR_MAIN, + value=f"AgentStack version {get_version()}", + ), + Title( + (self.height - 4, 3), + (1, self.width - 5), + color=COLOR_MAIN, + value=f"Window size: {self.width}x{self.height}", + ), + ], + ), + HelpText((self.height - 1, 0), (1, self.width)), ] - ) - questions['name'] = name - return questions +class State: + project: dict[str, Any] + # `active_agent` is the agent we are currently working on + active_agent: str + # `active_task` is the task we are currently working on + active_task: str + # `tool_category` is a temporary value while an agent is being created + tool_category: str + # `agents` is a dictionary of agents we have created + agents: dict[str, dict] + # `tasks` is a dictionary of tasks we have created + tasks: dict[str, dict] + + def __init__(self): + self.project = {} + self.agents = {} + self.tasks = {} + + def __repr__(self): + return f"State(project={self.project}, agents={self.agents}, tasks={self.tasks})" + + def create_project(self, name: str, description: str): + self.project = { + 'name': name, + 'description': description, + } + + def update_framework(self, framework: str): + self.project['framework'] = framework + + def create_agent(self, name: str, role: str, goal: str, backstory: str): + self.agents[name] = { + 'role': role, + 'goal': goal, + 'backstory': backstory, + 'llm': None, + 'tools': [], + } + self.active_agent = name + + def update_active_agent(self, **kwargs): + agent = self.agents[self.active_agent] + for key, value in kwargs.items(): + agent[key] = value + + def update_active_agent_tools(self, tool_name: str): + self.agents[self.active_agent]['tools'].append(tool_name) + + def create_task(self, name: str, description: str, expected_output: str): + self.tasks[name] = { + 'description': description, + 'expected_output': expected_output, + } + self.active_task = name + + def update_active_task(self, **kwargs): + task = self.tasks[self.active_task] + for key, value in kwargs.items(): + task[key] = value + + def to_template_config(self) -> TemplateConfig: + tools = [] + for agent_name, agent_data in self.agents.items(): + for tool_name in agent_data['tools']: + tools.append( + TemplateConfig.Tool( + name=tool_name, + agents=[agent_name], + ) + ) + + return TemplateConfig( + template_version=4, + name=self.project['name'], + description=self.project['description'], + framework=self.project['framework'], + method="sequential", + agents=[ + TemplateConfig.Agent( + name=agent_name, + role=agent_data['role'], + goal=agent_data['goal'], + backstory=agent_data['backstory'], + llm=agent_data['llm'], + ) + for agent_name, agent_data in self.agents.items() + ], + tasks=[ + TemplateConfig.Task( + name=task_name, + description=task_data['description'], + expected_output=task_data['expected_output'], + agent=self.active_agent, + ) + for task_name, task_data in self.tasks.items() + ], + tools=tools, + ) + + +class WizardApp(App): + views = { + 'welcome': BannerView, + 'framework': FrameworkView, + 'project': ProjectView, + 'after_project': AfterProjectView, + 'agent': AgentView, + 'model': ModelView, + 'tool_agent_selection': ToolAgentSelectionView, + 'tool_category': ToolCategoryView, + 'tool': ToolView, + 'after_agent': AfterAgentView, + 'task': TaskView, + 'task_agent_selection': TaskAgentSelectionView, + 'after_task': AfterTaskView, + 'debug': DebugView, + } + shortcuts = { + 'd': 'debug', + } + workflow = { + 'project': [ # initialize a project + 'welcome', + 'project', + 'framework', + 'after_project', + ], + 'agent': [ # add agents + 'agent', + 'model', + 'tool_category', + 'tool', + 'after_agent', + ], + 'task': [ # add tasks + 'task', + 'task_agent_selection', + 'after_task', + ], + 'tool': [ # add tools to an agent + 'tool_agent_selection', + 'tool_category', + 'tool', + 'after_agent', + ], + } + + state: State + active_workflow: Optional[str] + active_view: Optional[str] + + min_width: int = 80 + min_height: int = 24 + + # the main loop can still execute once more after this; so we create an + # explicit marker to ensure the template is only written once + _finish_run_once: bool = True + + def start(self): + """Load the first view in the default workflow.""" + view = self.workflow['project'][0] + self.load(view, workflow='project') + + def finish(self): + """Create the project, write the config file, and exit.""" + template = self.state.to_template_config() + + self.stop() + + if self._finish_run_once: + self._finish_run_once = False + log.set_stdout(sys.stdout) # re-enable on-screen logging + + init_project( + slug_name=template.name, + template_data=template, + ) + + template.write_to_file(conf.PATH / "wizard") + log.info(f"Saved template to: {conf.PATH / 'wizard.json'}") + + def advance(self, steps: int = 1): + """Load the next view in the active workflow.""" + assert self.active_workflow, "No active workflow set." + assert self.active_view, "No active view set." + + workflow = self.workflow[self.active_workflow] + current_index = workflow.index(self.active_view) + view = workflow[current_index + steps] + self.load(view, workflow=self.active_workflow) + + def back(self): + """Load the previous view in the active workflow.""" + return self.advance(-1) + + def load(self, view: str, workflow: Optional[str] = None): + """Load a view from a workflow.""" + self.active_workflow = workflow if workflow else self.active_workflow + self.active_view = view + super().load(view) + + @classmethod + def wrapper(cls, stdscr): + app = cls(stdscr) + app.state = State() + + app.start() + app.run() + + +def main(): + import io + log.set_stdout(io.StringIO()) # disable on-screen logging + curses.wrapper(WizardApp.wrapper) diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index 8c45e41d..90c0667e 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -54,6 +54,8 @@ class FrameworkModule(Protocol): Protocol spec for a framework implementation module. """ + NAME: str # Human readable name of the framework + DESCRIPTION: str # Human readable description of the framework ENTRYPOINT: Path """ Relative path to the entrypoint file for the framework in the user's project. @@ -302,6 +304,17 @@ def get_framework_module(framework: str) -> FrameworkModule: raise Exception(f"Framework {framework} could not be imported.") +def get_framework_info(framework: str) -> dict[str, str]: + """ + Get the info for a framework. + """ + _module = get_framework_module(framework) + return { + 'name': _module.NAME, + 'description': _module.DESCRIPTION, + } + + def get_entrypoint_path(framework: str) -> Path: """ Get the path to the entrypoint file for a framework. diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index 86e4f0e0..283aa3a5 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -13,6 +13,10 @@ NAME: str = "CrewAI" +DESCRIPTION: str = ( + "Framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative " + "intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks." +) ENTRYPOINT: Path = Path('src/crew.py') diff --git a/agentstack/frameworks/langgraph.py b/agentstack/frameworks/langgraph.py index 17540863..397a976c 100644 --- a/agentstack/frameworks/langgraph.py +++ b/agentstack/frameworks/langgraph.py @@ -13,8 +13,11 @@ from agentstack.tasks import TaskConfig, get_all_task_names from agentstack import graph - NAME: str = "LangGraph" +DESCRIPTION: str = ( + "A library for building stateful, multi-actor applications with LLMs, used to create " + "agent and multi-agent workflows." +) ENTRYPOINT: Path = Path('src/graph.py') GRAPH_NODE_START = 'START' diff --git a/agentstack/frameworks/llamaindex.py b/agentstack/frameworks/llamaindex.py index 269dd4b4..84dc77f8 100644 --- a/agentstack/frameworks/llamaindex.py +++ b/agentstack/frameworks/llamaindex.py @@ -12,6 +12,9 @@ from agentstack import graph NAME: str = "LLamaIndex" +DESCRIPTION: str = ( + "LlamaIndex is the leading framework for building LLM-powered agents over your data." +) ENTRYPOINT: Path = Path('src/stack.py') PROVIDERS = { diff --git a/agentstack/frameworks/openai_swarm.py b/agentstack/frameworks/openai_swarm.py index 026d1a6b..a365ef6d 100644 --- a/agentstack/frameworks/openai_swarm.py +++ b/agentstack/frameworks/openai_swarm.py @@ -13,6 +13,10 @@ NAME: str = "OpenAI Swarm" +DESCRIPTION: str = ( + "Educational framework exploring ergonomic, lightweight multi-agent orchestration. " + "Managed by OpenAI Solution team." +) ENTRYPOINT: Path = Path('src/stack.py') diff --git a/agentstack/main.py b/agentstack/main.py index 3139a9a5..adb5cb50 100644 --- a/agentstack/main.py +++ b/agentstack/main.py @@ -14,6 +14,7 @@ run_project, export_template, ) +from agentstack.cli import wizard from agentstack.telemetry import track_cli_command, update_telemetry from agentstack.utils import get_version, term_color from agentstack import generation @@ -68,10 +69,14 @@ def _main(): "init", aliases=["i"], help="Initialize a directory for the project", parents=[global_parser] ) init_parser.add_argument("slug_name", nargs="?", help="The directory name to place the project in") - init_parser.add_argument("--wizard", "-w", action="store_true", help="Use the setup wizard") + init_parser.add_argument("--wizard", "-w", action="store_true", help="Use the setup wizard [deprecated]") init_parser.add_argument("--template", "-t", help="Agent template to use") init_parser.add_argument("--framework", "-f", help="Framework to use") + wizard_parser = subparsers.add_parser( + "wizard", help="Run the setup wizard", parents=[global_parser] + ) + # 'run' command run_parser = subparsers.add_parser( "run", @@ -186,7 +191,13 @@ def _main(): elif args.command in ["templates"]: webbrowser.open("https://docs.agentstack.sh/quickstart") elif args.command in ["init", "i"]: - init_project(args.slug_name, args.template, args.framework, args.wizard) + if args.wizard: + log.warning("init --wizard is deprecated. Use `agentstack wizard`") + wizard.main() + else: + init_project(args.slug_name, args.template, args.framework) + elif args.command in ["wizard"]: + wizard.main() elif args.command in ["tools", "t"]: if args.tools_command in ["list", "l"]: list_tools() diff --git a/agentstack/providers.py b/agentstack/providers.py index f433d59e..48987e51 100644 --- a/agentstack/providers.py +++ b/agentstack/providers.py @@ -1,5 +1,66 @@ +from typing import Optional +import re +import pydantic from agentstack.exceptions import ValidationError +# model ids follow LiteLLM format +PREFERRED_MODELS = { + 'groq/deepseek-r1-distill-llama-70b': { + 'name': "DeepSeek R1 Distill Llama 70B", + 'host': "Groq", + 'description': "The Groq DeepSeek R1 Distill Llama 70B model", + }, + 'deepseek/deepseek-reasoner': { + 'name': "DeepSeek Reasoner", + 'host': "DeepSeek", + 'description': "The DeepSeek Reasoner model hosted by DeepSeek", + }, + 'openai/o1-preview': { + 'name': "o1 Preview", + 'host': "OpenAI", + 'description': "The OpenAI o1 Preview model", + }, + 'anthropic/claude-3-5-sonnet': { + 'name': "Claude 3.5 Sonnet", + 'host': "Anthropic", + 'description': "The Anthropic Claude 3.5 Sonnet model", + }, + # TODO there is no publicly available OpenRouter implementation for + # LangChain, so we can't recommend this yet. + # 'openrouter/deepseek/deepseek-r1': { + # 'name': "DeepSeek R1", + # 'host': "OpenRouter", + # 'description': "The DeepSeek R1 model hosted by OpenRouter", + # }, + 'openai/gpt-4o': { + 'name': "GPT-4o", + 'host': "OpenAI", + 'description': "The OpenAI GPT-4o model", + }, + 'anthropic/claude-3-opus': { + 'name': "Claude 3 Opus", + 'host': "Anthropic", + 'description': "The Anthropic Claude 3 Opus model", + }, +} + + +class ProviderConfig(pydantic.BaseModel): + id: str + name: Optional[str] + host: Optional[str] + description: Optional[str] + provider = property(lambda self: parse_provider_model(self.id)[0]) + model = property(lambda self: parse_provider_model(self.id)[1]) + + +def get_preferred_models() -> list[ProviderConfig]: + return [ProviderConfig(id=model_id, **model) for model_id, model in PREFERRED_MODELS.items()] + + +def get_preferred_model_ids() -> list[str]: + return [model.id for model in get_preferred_models()] + def parse_provider_model(model_id: str) -> tuple[str, str]: """Parse the provider and model name from the model ID""" @@ -12,4 +73,3 @@ def parse_provider_model(model_id: str) -> tuple[str, str]: return '/'.join(parts[:2]), parts[2] else: raise ValidationError(f"Model id '{model_id}' does not match expected format.") - diff --git a/agentstack/templates/__init__.py b/agentstack/templates/__init__.py index 1602fc56..7c450b8b 100644 --- a/agentstack/templates/__init__.py +++ b/agentstack/templates/__init__.py @@ -178,16 +178,16 @@ class TemplateConfig(pydantic.BaseModel): class Agent(pydantic.BaseModel): name: str - role: str - goal: str - backstory: str + role: Optional[str] + goal: Optional[str] + backstory: Optional[str] allow_delegation: bool = False llm: str class Task(pydantic.BaseModel): name: str - description: str - expected_output: str + description: Optional[str] + expected_output: Optional[str] agent: str # TODO this is redundant with the graph class Tool(pydantic.BaseModel): diff --git a/agentstack/tui.py b/agentstack/tui.py new file mode 100644 index 00000000..568e6600 --- /dev/null +++ b/agentstack/tui.py @@ -0,0 +1,1376 @@ +import curses +import signal +import time +import math +from random import randint +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union, Callable, Any +from enum import Enum +from pyfiglet import Figlet + +from agentstack import conf, log + + +class RenderException(Exception): + pass + + +# horizontal alignment +ALIGN_LEFT = "left" +ALIGN_CENTER = "center" +ALIGN_RIGHT = "right" + +# vertical alignment +ALIGN_TOP = "top" +ALIGN_MIDDLE = "middle" +ALIGN_BOTTOM = "bottom" + +# module positioning +POS_RELATIVE = "relative" +POS_ABSOLUTE = "absolute" + + +class Node: + """ + A simple data node that can be updated and have callbacks. This is used to + populate and retrieve data from an input field inside the user interface. + """ + + value: Any + callbacks: list[Callable] + + def __init__(self, value: Any = "") -> None: + self.value = value + self.callbacks = [] + + def __str__(self): + return str(self.value) + + def update(self, value: Any) -> None: + self.value = value + for callback in self.callbacks: + callback(self) + + def add_callback(self, callback): + self.callbacks.append(callback) + + def remove_callback(self, callback): + self.callbacks.remove(callback) + + +class Key: + """ + Conversions and convenience methods for key codes. + + Provides booleans about the key pressed: + + `key.BACKSPACE` + `key.is_numeric` + `key.is_alpha` + ... + """ + + const = { + 'UP': 259, + 'DOWN': 258, + 'BACKSPACE': 127, + 'TAB': 9, + 'ESC': 27, + 'ENTER': 10, + 'SPACE': 32, + 'PERIOD': 46, + 'PERCENT': 37, + 'MINUS': 45, + } + + def __init__(self, ch: int): + self.ch = ch + + def __getattr__(self, name) -> bool: + try: + return self.ch == self.const[name] + except KeyError: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + @property + def chr(self): + return chr(self.ch) + + @property + def is_numeric(self) -> bool: + return self.ch >= 48 and self.ch <= 57 + + @property + def is_alpha(self) -> bool: + return self.ch >= 65 and self.ch <= 122 + + +class Color: + """ + Color class based on HSV color space, mapping directly to terminal color capabilities. + + Hue: 0-360 degrees, mapped to 6 primary directions (0, 60, 120, 180, 240, 300) + Saturation: 0-100%, mapped to 6 levels (0, 20, 40, 60, 80, 100) + Value: 0-100%, mapped to 6 levels for colors, 24 levels for grayscale + """ + + # TODO: fallback for 16 color mode + # TODO: fallback for no color mode + BACKGROUND = curses.COLOR_BLACK + SATURATION_LEVELS = 12 + HUE_SEGMENTS = 6 + VALUE_LEVELS = 6 + GRAYSCALE_LEVELS = 24 + COLOR_CUBE_SIZE = 6 # 6x6x6 color cube + + reversed: bool = False + bold: bool = False + + FALLBACK_COLORS = [ + curses.COLOR_WHITE, + curses.COLOR_RED, + curses.COLOR_GREEN, + curses.COLOR_YELLOW, + curses.COLOR_BLUE, + curses.COLOR_MAGENTA, + curses.COLOR_CYAN, + ] + + _color_map = {} # cache for color mappings + COLOR_SUPPORT = "none" + + def __init__( + self, h: float, s: float = 100, v: float = 100, reversed: bool = False, bold: bool = False + ) -> None: + """ + Initialize color with HSV values. + + Args: + h: Hue (0-360 degrees) + s: Saturation (0-100 percent) + v: Value (0-100 percent) + """ + self.h = h % 360 + self.s = max(0, min(100, s)) + self.v = max(0, min(100, v)) + self.reversed = reversed + self.bold = bold + self._pair_number: Optional[int] = None + + def _get_closest_color(self) -> int: + """Map HSV to closest available terminal color number.""" + # Handle grayscale case + if self.s < 10: + gray_val = int(self.v * (self.GRAYSCALE_LEVELS - 1) / 100) + return 232 + gray_val if gray_val < self.GRAYSCALE_LEVELS else 231 + + # Convert HSV to the COLOR_CUBE_SIZE x COLOR_CUBE_SIZE x COLOR_CUBE_SIZE color cube + h = self.h + s = self.s / 100 + v = self.v / 100 + + # Map hue to primary and secondary colors (0 to HUE_SEGMENTS-1) + h = (h + 330) % 360 # -30 degrees = +330 degrees + h_segment = int((h / 60) % self.HUE_SEGMENTS) + h_remainder = (h % 60) / 60 + + # Get RGB values based on hue segment + max_level = self.COLOR_CUBE_SIZE - 1 + if h_segment == 0: # Red to Yellow + r, g, b = max_level, int(max_level * h_remainder), 0 + elif h_segment == 1: # Yellow to Green + r, g, b = int(max_level * (1 - h_remainder)), max_level, 0 + elif h_segment == 2: # Green to Cyan + r, g, b = 0, max_level, int(max_level * h_remainder) + elif h_segment == 3: # Cyan to Blue + r, g, b = 0, int(max_level * (1 - h_remainder)), max_level + elif h_segment == 4: # Blue to Magenta + r, g, b = int(max_level * h_remainder), 0, max_level + else: # Magenta to Red + r, g, b = max_level, 0, int(max_level * (1 - h_remainder)) + + # Apply saturation + max_rgb = max(r, g, b) + if max_rgb > 0: + # Map the saturation to the number of levels + s_level = int(s * (self.SATURATION_LEVELS - 1)) + s_factor = s_level / (self.SATURATION_LEVELS - 1) + + r = int(r + (max_level - r) * (1 - s_factor)) + g = int(g + (max_level - g) * (1 - s_factor)) + b = int(b + (max_level - b) * (1 - s_factor)) + + # Apply value (brightness) + v = max(0, min(max_level, int(v * self.VALUE_LEVELS))) + r = min(max_level, int(r * v / max_level)) + g = min(max_level, int(g * v / max_level)) + b = min(max_level, int(b * v / max_level)) + + # Convert to color cube index (16-231) + return int(16 + (r * self.COLOR_CUBE_SIZE * self.COLOR_CUBE_SIZE) + (g * self.COLOR_CUBE_SIZE) + b) + + def hue(self, h: float) -> 'Color': + """Set the hue of the color.""" + return Color(h, self.s, self.v, self.reversed, self.bold) + + def sat(self, s: float) -> 'Color': + """Set the saturation of the color.""" + return Color(self.h, s, self.v, self.reversed, self.bold) + + def val(self, v: float) -> 'Color': + """Set the value of the color.""" + return Color(self.h, self.s, v, self.reversed, self.bold) + + def reverse(self) -> 'Color': + """Set the reversed attribute of the color.""" + return Color(self.h, self.s, self.v, True, self.bold) + + def _get_color_pair(self, pair_number: int) -> int: + """Apply reversing to the color pair.""" + pair = curses.color_pair(pair_number) + if self.reversed: + pair = pair | curses.A_REVERSE + if self.bold: + pair = pair | curses.A_BOLD + return pair + + def _get_fallback_color(self): + hue = self.h + if self.s <= 50 or self.v <= 50: + return curses.COLOR_WHITE + + if hue < 30 or hue >= 330: + return curses.COLOR_RED + elif 30 < hue <= 90: + return curses.COLOR_YELLOW + elif 90 < hue <= 150: + return curses.COLOR_GREEN + elif 150 < hue <= 230: + return curses.COLOR_CYAN + elif 230 < hue <= 270: + return curses.COLOR_BLUE + elif 270 < hue <= 330: + return curses.COLOR_MAGENTA + else: + return curses.COLOR_WHITE + + def to_curses(self) -> int: + """Get curses color pair for this color.""" + if self._pair_number is not None: + return self._get_color_pair(self._pair_number) + + if Color.COLOR_SUPPORT == "none": + return 0 + + if Color.COLOR_SUPPORT == "basic": + color_number = self._get_fallback_color() + else: + color_number = self._get_closest_color() + + # Create new pair if needed + if color_number not in self._color_map: + pair_number = len(self._color_map) + 1 + curses.init_pair(pair_number, color_number, self.BACKGROUND) + self._color_map[color_number] = pair_number + else: + pair_number = self._color_map[color_number] + + self._pair_number = pair_number + return self._get_color_pair(pair_number) + + @classmethod + def initialize(cls) -> None: + """Initialize terminal color support.""" + cls._color_map = {} + cls.COLOR_SUPPORT = "none" + + if not curses.has_colors(): + return + + try: + curses.start_color() + curses.use_default_colors() + curses.init_pair(1, 1, -1) + curses.color_pair(1) + cls.COLOR_SUPPORT = "full" if curses.COLORS >= 256 else "basic" + except: + pass + + +class ColorAnimation(Color): + """ + Animate between two colors over a duration. + + Compatible interface with `Color` to add animation to element's color. + """ + + start: Color + end: Color + reversed: bool = False + bold: bool = False + duration: float + loop: bool + _start_time: float + + def __init__(self, start: Color, end: Color, duration: float, loop: bool = False): + super().__init__(start.h, start.s, start.v) + self.start = start + self.end = end + self.duration = duration + self.loop = loop + self._start_time = time.time() + + def reset_animation(self): + self._start_time = time.time() + + def to_curses(self) -> int: + if self.reversed: + self.start.reversed = True + self.end.reversed = True + elif self.start.reversed: + self.reversed = True + + if self.bold: + self.start.bold = True + self.end.bold = True + elif self.start.bold: + self.bold = True + + elapsed = time.time() - self._start_time + if elapsed > self.duration: + if self.loop: + self.start, self.end = self.end, self.start + self.reset_animation() + return self.start.to_curses() # prevents flickering :shrug: + else: + return self.end.to_curses() + + t = elapsed / self.duration + h1, h2 = self.start.h, self.end.h + # take the shortest path + diff = h2 - h1 + if abs(diff) > 180: + if diff > 0: + h1 += 360 + else: + h2 += 360 + h = (h1 + t * (h2 - h1)) % 360 + + # saturation and value + s = self.start.s + t * (self.end.s - self.start.s) + v = self.start.v + t * (self.end.v - self.start.v) + + return Color(h, s, v, reversed=self.reversed, bold=self.bold).to_curses() + + +class Renderable: + """ + A base class for all renderable modules. + + Handles sizing, positioning, and inserting the module into the grid. + """ + + _grid: Optional[curses.window] = None + y: int + x: int + height: int + width: int + parent: Optional['Contains'] = None + h_align: str = ALIGN_LEFT + v_align: str = ALIGN_TOP + color: Color + last_render: float = 0 + padding: tuple[int, int] = (1, 1) + positioning: str = POS_ABSOLUTE + + def __init__(self, coords: tuple[int, int], dims: tuple[int, int], color: Optional[Color] = None): + self.y, self.x = coords + self.height, self.width = dims + self.color = color or Color(0, 100, 0) + + def __repr__(self): + return f"{type(self)} at ({self.y}, {self.x})" + + @property + def grid(self): + # TODO validate that coords and size are within the parent window and give + # an explanatory error message. + if not self._grid: + if self.parent: + if self.positioning == POS_RELATIVE: + grid_func = self.parent.grid.derwin + elif self.positioning == POS_ABSOLUTE: + grid_func = self.parent.grid.subwin + else: + raise ValueError("Invalid positioning value") + else: + grid_func = curses.newwin + + self._grid = grid_func( + self.height + self.padding[0], self.width + self.padding[1], self.y, self.x + ) # TODO this cant be bigger than the window + self._grid.bkgd(' ', curses.color_pair(1)) + return self._grid + + def move(self, y: int, x: int): + """Move the module's grid to a new position.""" + self.y, self.x = y, x + if self._grid: + if self.positioning == POS_RELATIVE: + self._grid.mvderwin(self.y, self.x) + elif self.positioning == POS_ABSOLUTE: + self._grid.mvwin(self.y, self.x) + else: + raise ValueError("Cannot move a root window") + + @property + def abs_x(self): + """Absolute X coordinate of this module""" + if self.parent and not self.positioning == POS_ABSOLUTE: + return self.x + self.parent.abs_x + return self.x + + @property + def abs_y(self): + """Absolute Y coordinate of this module""" + if self.parent and not self.positioning == POS_ABSOLUTE: + return self.y + self.parent.abs_y + return self.y + + def render(self): + """Render the module to the screen.""" + pass + + def hit(self, y, x): + """Is the mouse click inside this module?""" + return ( + y >= self.abs_y + and y < self.abs_y + self.height + and x >= self.abs_x + and x < self.abs_x + self.width + ) + + def click(self, y, x): + """Handle mouse click event.""" + pass + + def input(self, key: Key): + """Handle key input event.""" + pass + + def destroy(self) -> None: + if self._grid: + self._grid.erase() + self._grid.refresh() + self._grid = None + + +class Element(Renderable): + positioning: str = POS_RELATIVE + word_wrap: bool = False + + def __init__( + self, + coords: tuple[int, int], + dims: tuple[int, int], + value: Optional[Any] = "", + color: Optional[Color] = None, + ): + super().__init__(coords, dims, color=color) + self.value = value + + def __repr__(self): + return f"{type(self)} at ({self.y}, {self.x}) with value '{self.value[:20]}'" + + def _get_lines(self, value: str) -> list[str]: + """ + Get the lines to render. + + Called by `render()` using the value of the element. This allows us to have + word wrapping and alignment in all module types. + """ + if self.word_wrap: + splits = [''] * self.height + words = value.split() + for i in range(self.height): + while words and (len(splits[i]) + len(words[0]) + 1) <= self.width: + splits[i] += f"{words.pop(0)} " if words else '' + elif '\n' in value: + splits = value.split('\n') + else: + splits = [ + value, + ] + + if self.v_align == ALIGN_TOP: + # add empty elements below + splits = splits + [''] * (self.height - len(splits)) + elif self.v_align == ALIGN_MIDDLE: + # add empty elements before and after the splits to center it + pad = (self.height // 2) - (len(splits) // 2) + splits = [''] * pad + splits + [''] * pad + elif self.v_align == ALIGN_BOTTOM: + splits = [''] * (self.height - len(splits)) + splits + + lines = [] + for line in splits: + if self.h_align == ALIGN_LEFT: + line = line.ljust(self.width) + elif self.h_align == ALIGN_RIGHT: + line = line.rjust(self.width) + elif self.h_align == ALIGN_CENTER: + line = line.center(self.width) + + lines.append(line[: self.width]) + return lines + + def render(self): + for i, line in enumerate(self._get_lines(str(self.value))): + try: + self.grid.addstr(i, 0, line, self.color.to_curses()) + except curses.error: + pass # ignore overflow + + +class NodeElement(Element): + """ + A module that is bound to a node and updates when the node changes. + """ + + format: Optional[Callable] = None + + def __init__( + self, + coords: tuple[int, int], + dims: tuple[int, int], + node: Node, + color: Optional[Color] = None, + ): + super().__init__(coords, dims, color=color) + self.node = node + self.value = str(node) + self.node.add_callback(self.update) # allow the node to listen for changes + + def update(self, node: Node): + self.value = str(node) + + def save(self): + self.node.update(self.value) + + def destroy(self): + self.node.remove_callback(self.update) + super().destroy() + + +class Editable(NodeElement): + """ + A module that can be edited by the user. + + Handles mouse clicks, key input, and managing global editing state. + """ + + active: bool + _original_value: Any + + def __init__( + self, + coords, + dims, + node: Node, + color=None, + ): + super().__init__(coords, dims, node=node, color=color) + self.active = False + self._original_value = self.value + + def click(self, y, x): + if not self.active and self.hit(y, x): + self.activate() + elif self.active: # click off + self.deactivate(save=False) + + def activate(self): + """Make this module the active one; ie. editing or selected.""" + App.editing = True + self.active = True + self._original_value = self.value + + def deactivate(self, save: bool = True): + """Deactivate this module, making it no longer active.""" + App.editing = False + self.active = False + if save: + self.save() + + def input(self, key: Key): + if not self.active: + return + + if key.is_alpha or key.is_numeric or key.PERIOD or key.MINUS or key.SPACE: + self.value = str(self.value) + key.chr + elif key.BACKSPACE: + self.value = str(self.value)[:-1] + elif key.ESC: + self.deactivate(save=False) + self.value = self._original_value # revert changes + elif key.ENTER: + self.deactivate() + + def destroy(self): + self.deactivate() + super().destroy() + + +class TextInput(Editable): + """ + A module that allows the user to input text. + """ + + H, V, BR = "━", "┃", "┛" + padding: tuple[int, int] = (2, 1) + border_color: Color + active_color: Color + placeholder: str = "" + word_wrap: bool = True + + def __init__( + self, + coords: tuple[int, int], + dims: tuple[int, int], + node: Node, + placeholder: str = "", + color: Optional[Color] = None, + border: Optional[Color] = None, + active: Optional[Color] = None, + ): + super().__init__(coords, dims, node=node, color=color) + self.width, self.height = (dims[1] - 1, dims[0] - 1) + self.border_color = border or self.color + self.active_color = active or self.color + self.placeholder = placeholder + if self.value == "": + self.value = self.placeholder + + def activate(self): + # change the border color to a highlight + self._original_border_color = self.border_color + self.border_color = self.active_color + if self.value == self.placeholder: + self.value = "" + super().activate() + + def deactivate(self, save: bool = True): + if self.active and hasattr(self, '_original_border_color'): + self.border_color = self._original_border_color + if self.value == "": + self.value = self.placeholder + super().deactivate(save) + + def save(self): + if self.value != self.placeholder: + super().save() + + def render(self) -> None: + if self.value == self.placeholder: + color = self.color.to_curses() | curses.A_ITALIC + else: + color = self.color.to_curses() + for i, line in enumerate(self._get_lines(str(self.value))): + self.grid.addstr(i, 0, line, color) + + # # add border to bottom right like a drop shadow + for x in range(self.width): + self.grid.addch(self.height, x, self.H, self.border_color.to_curses()) + for y in range(self.height): + self.grid.addch(y, self.width, self.V, self.border_color.to_curses()) + self.grid.addch(self.height, self.width, self.BR, self.border_color.to_curses()) + + +class Text(Element): + """Basic text module""" + + pass + + +class WrappedText(Text): + """Text module with word wrapping""" + + word_wrap: bool = True + + +class ASCIIText(Text): + """Text module that renders as ASCII art""" + + default_font: str = "pepper" + formatter: Figlet + _ascii_render: Optional[str] = None # rendered content + _ascii_value: Optional[str] = None # value used to render content + + def __init__( + self, + coords: tuple[int, int], + dims: tuple[int, int], + value: Optional[Any] = "", + color: Optional[Color] = None, + formatter: Optional[Figlet] = None, + ): + super().__init__(coords, dims, value=value, color=color) + self.formatter = formatter or Figlet(font=self.default_font) + + def _get_lines(self, value: str) -> list[str]: + if not self._ascii_render or self._ascii_value != value: + # prevent rendering on every frame + self._ascii_value = value + self._ascii_render = self.formatter.renderText(value) or "" + + return super()._get_lines(self._ascii_render) + + +class BoldText(Text): + """Text module with bold text""" + + def __init__( + self, + coords: tuple[int, int], + dims: tuple[int, int], + value: Optional[Any] = "", + color: Optional[Color] = None, + ): + super().__init__(coords, dims, value=value, color=color) + self.color.bold = True + + +class Title(BoldText): + """A title module; shortcut for bold, centered text""" + + h_align: str = ALIGN_CENTER + v_align: str = ALIGN_MIDDLE + + +class Button(Element): + """A clickable button module""" + + h_align: str = ALIGN_CENTER + v_align: str = ALIGN_MIDDLE + active: bool = False + selected: bool = False + highlight: Optional[Color] = None + on_confirm: Optional[Callable] = None + on_activate: Optional[Callable] = None + + def __init__( + self, + coords: tuple[int, int], + dims: tuple[int, int], + value: Optional[Any] = "", + color: Optional[Color] = None, + highlight: Optional[Color] = None, + on_confirm: Optional[Callable] = None, + on_activate: Optional[Callable] = None, + ): + super().__init__(coords, dims, value=value, color=color) + self.highlight = highlight or self.color.sat(50) + self.on_confirm = on_confirm + self.on_activate = on_activate + + def confirm(self): + """Handle button confirmation.""" + if self.on_confirm: + self.on_confirm() + + def activate(self): + """Make this module the active one; ie. editing or selected.""" + self.active = True + self._original_color = self.color + self.color = self.highlight or self.color + if hasattr(self.color, 'reset_animation'): + self.color.reset_animation() + if self.on_activate: + self.on_activate(self.value) + + def deactivate(self, save: bool = True): + """Deactivate this module, making it no longer active.""" + self.active = False + if hasattr(self, '_original_color'): + self.color = self._original_color + + def click(self, y, x): + if self.hit(y, x): + self.confirm() + + def input(self, key: Key): + """Handle key input event.""" + if not self.active: + return + + if key.ENTER or key.SPACE: + self.confirm() + + +class RadioButton(Button): + """A Button with an indicator that it is selected""" + + ON, OFF = "●", "○" + selected: bool = False + + def render(self): + super().render() + icon = self.ON if self.selected else self.OFF + self.grid.addstr(1, 2, icon, self.color.to_curses()) + + +class CheckButton(RadioButton): + """A Button with an indicator that it is selected""" + + ON, OFF = "■", "□" + + +class Contains(Renderable): + """A container for other modules""" + + _grid: Optional[curses.window] = None + y: int + x: int + positioning: str = POS_RELATIVE + padding: tuple[int, int] = (1, 0) + color: Color + last_render: float = 0 + parent: Optional['Contains'] = None + modules: list[Renderable] + + def __init__( + self, + coords: tuple[int, int], + dims: tuple[int, int], + modules: list[Renderable], + color: Optional[Color] = None, + ): + super().__init__(coords, dims, color=color) + self.modules = [] + for module in modules: + self.append(module) + + def append(self, module: Renderable): + module.parent = self + self.modules.append(module) + + def get_modules(self): + """Override this to filter displayed modules""" + return self.modules + + def render(self): + for module in self.get_modules(): + try: + module.render() + module.last_render = time.time() + module.grid.noutrefresh() + except RenderException: + pass # ignore overflow + self.last_render = time.time() + + def click(self, y, x): + for module in self.modules: + module.click(y, x) + + def input(self, key: Key): + for module in self.modules: + module.input(key) + + def destroy(self): + for module in self.modules: + module.destroy() + self.grid.erase() + self.grid.refresh() + + +class Box(Contains): + """A container with a border""" + + H, V, TL, TR, BL, BR = "─", "│", "┌", "┐", "└", "┘" + + def render(self) -> None: + w: int = self.width - 1 + h: int = self.height - 1 + + for x in range(1, w): + self.grid.addch(0, x, self.H, self.color.to_curses()) + self.grid.addch(h, x, self.H, self.color.to_curses()) + for y in range(1, h): + self.grid.addch(y, 0, self.V, self.color.to_curses()) + self.grid.addch(y, w, self.V, self.color.to_curses()) + self.grid.addch(0, 0, self.TL, self.color.to_curses()) + self.grid.addch(h, 0, self.BL, self.color.to_curses()) + self.grid.addch(0, w, self.TR, self.color.to_curses()) + self.grid.addch(h, w, self.BR, self.color.to_curses()) + + for module in self.get_modules(): + try: + module.render() + module.last_render = time.time() + module.grid.noutrefresh() + except RenderException: + pass # ignore overflow + self.last_render = time.time() + self.grid.noutrefresh() + + +class LightBox(Box): + """A Box with light borders""" + + pass + + +class HeavyBox(Box): + """A Box with heavy borders""" + + H, V, TL, TR, BL, BR = "━", "┃", "┏", "┓", "┗", "┛" + + +class DoubleBox(Box): + """A Box with double borders""" + + H, V, TL, TR, BL, BR = "═", "║", "╔", "╗", "╚", "╝" + + +class Select(Box): + """ + Build a select menu out of buttons. + """ + + UP, DOWN = "▲", "▼" + on_change: Optional[Callable] = None + on_select: Optional[Callable] = None + button_cls: type[Button] = Button + button_height: int = 3 + show_up: bool = False + show_down: bool = False + + def __init__( + self, + coords: tuple[int, int], + dims: tuple[int, int], + options: list[str], + color: Optional[Color] = None, + highlight: Optional[Color] = None, + on_change: Optional[Callable] = None, + on_select: Optional[Callable] = None, + ) -> None: + super().__init__(coords, dims, [], color=color) + self.highlight = highlight or Color(0, 100, 100) + self.options = options + self.on_change = on_change + self.on_select = on_select + + for i, option in enumerate(self.options): + self.append(self._get_button(i, option)) + self._mark_active(0) + + def _get_button(self, index: int, option: str) -> Button: + """Helper to create a button for an option""" + return self.button_cls( + ((index * self.button_height) + 1, 1), + (self.button_height, self.width - 2), + value=option, + color=self.color, + highlight=self.highlight, + on_activate=lambda _: self._button_on_activate(index, option), + ) + + def _button_on_activate(self, index: int, option: str): + """Callback for when a button is activated.""" + if self.on_change: + self.on_change(index, option) + + def _mark_active(self, index: int): + """Mark a submodule as active.""" + for module in self.modules: + assert hasattr(module, 'deactivate') + module.deactivate() + + active = self.modules[index] + assert hasattr(active, 'active') + if not active.active: + assert hasattr(active, 'activate') + active.activate() + + if self.on_change: + self.on_change(index, self.options[index]) + + def _get_active_index(self): + """Get the index of the active option.""" + for module in self.modules: + if module.active: + return self.modules.index(module) + return None + + def get_modules(self): + """Return a subset of modules to be rendered""" + # since we can't always render all of the buttons, return a subset + # that can be displayed in the available height. + num_displayed = (self.height - 3) // self.button_height + index = self._get_active_index() or 0 + count = len(self.modules) + + if count <= num_displayed: + start = 0 + self.show_up = False + else: + ideal_start = index - (num_displayed // 2) + start = min(ideal_start, count - num_displayed) + start = max(0, start) + self.show_up = bool(start > 0) + + end = min(start + num_displayed, count) + self.show_down = bool(end < count) + visible = self.modules[start:end] + + for i, module in enumerate(visible): + pad = 2 if self.show_up else 1 + module.move((i * self.button_height) + pad, module.x) + return visible + + def render(self): + """Render all options and conditionally show up/down arrows.""" + for module in self.modules: + if module.last_render: + module.grid.erase() + + self.grid.erase() + if self.show_up: + self.grid.addstr(1, 1, self.UP.center(self.width - 2), self.color.to_curses()) + if self.show_down: + self.grid.addstr(self.height - 2, 1, self.DOWN.center(self.width - 2), self.color.to_curses()) + + super().render() + + def select(self, option: Button): + """Select an option; ie. mark it as the value of this element.""" + index = self.modules.index(option) + option.selected = not option.selected + self.value = self.options[index] + self._mark_active(index) + if self.on_select: + self.on_select(index, self.options[index]) + + def input(self, key: Key): + """Handle key input event.""" + index = self._get_active_index() + + if index is None: + return # can't select a non-active element + + if key.UP or key.DOWN: + direction = -1 if key.UP else 1 + index = direction + index + if index < 0 or index >= len(self.modules): + return # don't loop + self._mark_active(index) + elif key.SPACE or key.ENTER: + self.select(self.modules[index]) + + super().input(key) + + def click(self, y, x): + for module in self.get_modules(): + if not module.hit(y, x): + continue + self.select(module) + + +class RadioSelect(Select): + """Allow one button to be `selected` at a time""" + + button_cls = RadioButton + + def __init__( + self, + coords: tuple[int, int], + dims: tuple[int, int], + options: list[str], + color: Optional[Color] = None, + highlight: Optional[Color] = None, + on_change: Optional[Callable] = None, + on_select: Optional[Callable] = None, + ) -> None: + super().__init__( + coords, dims, options, color=color, highlight=highlight, on_change=on_change, on_select=on_select + ) + self.select(self.modules[0]) # type: ignore[arg-type] + + def select(self, module: Button): + """Radio buttons only allow a single selection.""" + for _module in self.modules: + assert hasattr(_module, 'selected') + _module.selected = False + super().select(module) + + +class MultiSelect(Select): + """Allow multiple buttons to be `selected` at a time""" + + button_cls = CheckButton + + +class ColorWheel(Element): + """ + A module used for testing color display. + """ + + width: int = 80 + height: int = 24 + + def __init__(self, coords: tuple[int, int], duration: float = 10.0): + super().__init__(coords, (self.height, self.width)) + self.duration = duration + self.start_time = time.time() + + def render(self) -> None: + self.grid.erase() + center_y, center_x = 12, 22 + radius = 10 + elapsed = time.time() - self.start_time + hue_offset = (elapsed / self.duration) * 360 # animate + + for y in range(center_y - radius, center_y + radius + 1): + for x in range(center_x - radius * 2, center_x + radius * 2 + 1): + # Convert position to polar coordinates + dx = (x - center_x) / 2 # Compensate for terminal character aspect ratio + dy = y - center_y + distance = math.sqrt(dx * dx + dy * dy) + + if distance <= radius: + # Convert to HSV + angle = math.degrees(math.atan2(dy, dx)) + # h = (angle + 360) % 360 + h = (angle + hue_offset) % 360 + s = (distance / radius) * 100 + v = 100 # (distance / radius) * 100 + + color = Color(h, s, v) + self.grid.addstr(y, x, "█", color.to_curses()) + + x = 50 + y = 4 + for i in range(0, curses.COLORS): + self.grid.addstr(y, x, f"███", curses.color_pair(i + 1)) + y += 1 + if y >= self.height - 4: + y = 4 + x += 3 + if x >= self.width - 3: + break + + self.grid.refresh() + + +class DebugElement(Element): + """Show fps and color usage.""" + + def __init__(self, coords: tuple[int, int]): + super().__init__(coords, (1, 40)) + + def render(self) -> None: + self.grid.addstr(0, 1, f"FPS: {1 / (time.time() - self.last_render):.0f}") + self.grid.addstr(0, 10, f"Colors: {len(Color._color_map)}/{curses.COLORS}") + self.grid.addstr(0, 27, f"Dims: {self.parent.width}x{self.parent.height}") # type: ignore + + +class View(Contains): + app: 'App' + positioning: str = POS_ABSOLUTE + padding: tuple[int, int] = (0, 0) + y: int = 0 + x: int = 0 + + def __init__(self, app: 'App'): + self.app = app + self.modules = [] + + def init(self, dims: tuple[int, int]) -> None: + self.height, self.width = dims + self.modules = self.layout() + + if conf.DEBUG: + self.append(DebugElement((1, 1))) + + @property + def grid(self): + if not self._grid: + self._grid = curses.newwin(self.height, self.width, self.y, self.x) + self._grid.bkgd(' ', curses.color_pair(1)) + return self._grid + + def layout(self) -> list[Renderable]: + """Override this in subclasses to define the layout of the view.""" + log.warning(f"`layout` not implemented in View: {self.__class__}.") + return [] + + +class App: + """The main application class.""" + + stdscr: curses.window + height: int + width: int + min_height: int = 24 + min_width: int = 80 + frame_time: float = 1.0 / 60 # 30 FPS + editing = False + view: Optional[View] = None # the active view + views: dict[str, type[View]] = {} + shortcuts: dict[str, str] = {} + + def __init__(self, stdscr: curses.window) -> None: + self.stdscr = stdscr + self.height, self.width = self.stdscr.getmaxyx() + + if not self.width >= self.min_width or not self.height >= self.min_height: + raise RenderException( + f"Terminal window is too small. Resize to at least {self.min_width}x{self.min_height}. \n" + f"Current size: {self.width}x{self.height}" + ) + + curses.curs_set(0) + stdscr.nodelay(True) + stdscr.timeout(10) # balance framerate with cpu usage + curses.mousemask(curses.BUTTON1_CLICKED | curses.REPORT_MOUSE_POSITION) + + Color.initialize() + curses.init_pair(1, curses.COLOR_WHITE, curses.COLOR_BLACK) + + def add_view(self, name: str, view_cls: type[View], shortcut: Optional[str] = None) -> None: + self.views[name] = view_cls + if shortcut: + self.shortcuts[shortcut] = name + + def load(self, view_name: str): + if self.view: + self.view.destroy() + self.view = None + + view_cls = self.views[view_name] + self.view = view_cls(self) + self.view.init((self.height, self.width)) + + def run(self): + frame_time = 1.0 / 60 # 30 FPS + last_frame = time.time() + self._running = True + while self._running: + self.height, self.width = self.stdscr.getmaxyx() + + current_time = time.time() + delta = current_time - last_frame + ch = self.stdscr.getch() + + if ch == curses.KEY_MOUSE: + try: + _, x, y, _, bstate = curses.getmouse() + if not bstate & curses.BUTTON1_CLICKED: + continue # only allow left click + self.click(y, x) + except curses.error: + pass + elif ch != -1: + self.input(ch) + + if not App.editing: + if ch == ord('q'): + break + elif ch in [ord(x) for x in self.shortcuts.keys()]: + self.load(self.shortcuts[chr(ch)]) + + if delta >= self.frame_time or ch != -1: + self.render() + delta = 0 + if delta < self.frame_time: + time.sleep(frame_time - delta) + + def stop(self): + if self.view: + self.view.destroy() + self._running = False + curses.endwin() + + def render(self): + if not self.view: + return + + # handle resize + height, width = self.stdscr.getmaxyx() + if self.view.width != width or self.view.height != height: + self.width, self.height = width, height + for name, cls in self.views.items(): + if cls == self.view.__class__: + break + self.load(name) + + # render loop + try: + self.view.render() + self.view.last_render = time.time() + self.view.grid.noutrefresh() + curses.doupdate() + except curses.error as e: + log.debug(f"Error rendering view: {e}") + if "add_wch() returned ERR" in str(e): + raise RenderException("Grid not large enough to render all modules.") + if "curses function returned NULL" in str(e): + pass + # raise RenderException("Window not large enough to render.") + raise e + + def click(self, y, x): + """Handle mouse click event.""" + if self.view: + self.view.click(y, x) + + def input(self, ch: int): + """Handle key input event.""" + key = Key(ch) + + if key.TAB: + self._select_next_tabbable() + + if self.view: + self.view.input(key) + + def _get_tabbable_modules(self): + """ + Search through the tree of modules to find selectable elements. + """ + + def _get_activateable(module: Element): + """Find modules with an `activate` method""" + if hasattr(module, 'activate'): + yield module + for submodule in getattr(module, 'modules', []): + yield from _get_activateable(submodule) + + return list(_get_activateable(self.view)) + + def _select_next_tabbable(self): + """ + Activate the next tabbable module in the list. + """ + + def _get_active_module(module: Element): + if hasattr(module, 'active') and module.active: + return module + for submodule in getattr(module, 'modules', []): + active = _get_active_module(submodule) + if active: + return active + return None + + modules = self._get_tabbable_modules() + active_module = _get_active_module(self.view) + if active_module: + for module in modules: + module.deactivate() + next_index = modules.index(active_module) + 1 + if next_index >= len(modules): + next_index = 0 + modules[next_index].activate() # TODO this isn't working + elif modules: + modules[0].activate() diff --git a/docs/llms.txt b/docs/llms.txt index 2e0af5b7..fa23ea5c 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -514,10 +514,6 @@ which adheres to a common pattern or exporting your project to share. Templates are versioned, and each previous version provides a method to convert it's content to the current version. -> TODO: Templates are currently identified as `proj_templates` since they conflict -with the templates used by `generation`. Move existing templates to be part of -the generation package. - ### `TemplateConfig.from_user_input(identifier: str)` `` Returns a `TemplateConfig` object for either a URL, file path, or builtin template name. @@ -716,7 +712,7 @@ title: 'System Analyzer' description: 'Inspect a project directory and improve it' --- -[View Template](https://github.com/AgentOps-AI/AgentStack/blob/main/agentstack/templates/proj_templates/system_analyzer.json) +[View Template](https://github.com/AgentOps-AI/AgentStack/blob/main/agentstack/templates/system_analyzer.json) ```bash agentstack init --template=system_analyzer @@ -737,7 +733,7 @@ title: 'Researcher' description: 'Research and report result from a query' --- -[View Template](https://github.com/AgentOps-AI/AgentStack/blob/main/agentstack/templates/proj_templates/research.json) +[View Template](https://github.com/AgentOps-AI/AgentStack/blob/main/agentstack/templates/research.json) ```bash agentstack init --template=research @@ -828,7 +824,54 @@ title: 'Content Creator' description: 'Research a topic and create content on it' --- -[View Template](https://github.com/AgentOps-AI/AgentStack/blob/main/agentstack/templates/proj_templates/content_creator.json) +[View Template](https://github.com/AgentOps-AI/AgentStack/blob/main/agentstack/templates/content_creator.json) + +## frameworks/list.mdx + +--- +title: Frameworks +description: 'Supported frameworks in AgentStack' +icon: 'ship' +--- + +These are documentation links to the frameworks supported directly by AgentStack. + +To start a project with one of these frameworks, use +```bash +agentstack init --framework +``` + +## Framework Docs + + + An intuitive agentic framework (recommended) + + + A complex but capable framework with a _steep_ learning curve + + + A simple framework with a cult following + + + An expansive framework with many ancillary features + + ## tools/package-structure.mdx @@ -1043,7 +1086,7 @@ You can pass the `--wizard` flag to `agentstack init` to use an interactive proj You can also pass a `--template=` argument to `agentstack init` which will pre-populate your project with functionality from a built-in template, or one found on the internet. A `template_name` can be one of three identifiers: -- A built-in AgentStack template (see the `templates/proj_templates` directory in the AgentStack repo for bundled templates). +- A built-in AgentStack template (see the `templates` directory in the AgentStack repo for bundled templates). - A template file from the internet; pass the full https URL of the template. - A local template file; pass an absolute or relative path. diff --git a/pyproject.toml b/pyproject.toml index 5cbc9bc8..c7f2ace3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "agentops>=0.3.18", "typer>=0.12.5", "inquirer>=3.4.0", - "art>=6.3", + "pyfiglet==1.0.2", "toml>=0.10.2", "ruamel.yaml.base>=0.3.2", "cookiecutter==2.6.0", diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 96d4edf8..35abd9df 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -11,6 +11,7 @@ from agentstack._tools import ToolConfig, get_all_tools from agentstack.agents import AGENTS_FILENAME, AgentConfig from agentstack.tasks import TASKS_FILENAME, TaskConfig +from agentstack.providers import get_preferred_model_ids from agentstack import graph BASE_PATH = Path(__file__).parent @@ -144,6 +145,14 @@ def test_get_agent_tool_names(self): tool_names = frameworks.get_agent_tool_names('agent_name') assert tool_names == ['test_tool'] + @parameterized.expand([(x, ) for x in get_preferred_model_ids()]) + def test_add_agent_preferred_models(self, llm: str): + """Test adding an Agent to the graph with all preferred models we support""" + self._populate_min_entrypoint() + agent = self._get_test_agent() + agent.llm = llm + frameworks.add_agent(agent) + def test_add_tool(self): self._populate_max_entrypoint() frameworks.add_tool(self._get_test_tool(), 'agent_name') diff --git a/tests/test_providers.py b/tests/test_providers.py index 42d6e41a..e89b39e4 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,7 +1,11 @@ import unittest from agentstack.exceptions import ValidationError from agentstack.providers import ( + PREFERRED_MODELS, + ProviderConfig, parse_provider_model, + get_preferred_model_ids, + get_preferred_models, ) @@ -23,8 +27,16 @@ def test_parse_provider_model(self): ] for case, expect in zip(cases, expected): self.assertEqual(parse_provider_model(case), expect) - + def test_invalid_provider_model(self): with self.assertRaises(ValidationError): parse_provider_model("invalid_provider_model") + + def test_all_preferred_provider_config(self): + for model in get_preferred_models(): + self.assertIsInstance(model, ProviderConfig) + + def test_all_preferred_model_ids(self): + for model_id in get_preferred_model_ids(): + self.assertIsInstance(model_id, str) diff --git a/tests/test_tool_config.py b/tests/test_tool_config.py index bf187e44..bc2c3f67 100644 --- a/tests/test_tool_config.py +++ b/tests/test_tool_config.py @@ -2,7 +2,16 @@ import unittest import re from pathlib import Path -from agentstack._tools import ToolConfig, get_all_tool_paths, get_all_tool_names +from agentstack.exceptions import ValidationError +from agentstack._tools import ( + ToolConfig, + get_all_tools, + get_all_tool_paths, + get_all_tool_names, + ToolCategory, + get_all_tool_categories, + get_all_tool_category_names, +) BASE_PATH = Path(__file__).parent @@ -43,6 +52,17 @@ def test_dependency_versions(self): "All dependencies must include version specifications." ) + def test_tool_category(self): + categories = get_all_tool_categories() + assert categories + for category in categories: + assert category.name in get_all_tool_category_names() + assert isinstance(category, ToolCategory) + + def test_all_tools_have_valid_categories(self): + for tool_config in get_all_tools(): + assert tool_config.category in get_all_tool_category_names() + def test_all_json_configs_from_tool_name(self): for tool_name in get_all_tool_names(): config = ToolConfig.from_tool_name(tool_name) @@ -60,3 +80,7 @@ def test_all_json_configs_from_tool_path(self): ) assert config.name == path.stem + + def test_get_missing_tool(self): + with self.assertRaises(ValidationError): + ToolConfig.from_tool_name("missing_tool") \ No newline at end of file