diff --git a/examples/slackbot/agents.py b/examples/slackbot/agents.py
new file mode 100644
index 0000000..b2966ff
--- /dev/null
+++ b/examples/slackbot/agents.py
@@ -0,0 +1,193 @@
+"""Multi-agent pipeline that answers Prefect questions for the Slack bot example."""
+
+import os
+import re
+from typing import Annotated
+
+from pydantic import BaseModel, Field
+from tools import search_internet, search_knowledge_base
+
+import controlflow as cf
+
+# Upper bound on explore -> audit rounds, so an auditor that is never satisfied
+# cannot keep the bot issuing LLM calls forever.
+MAX_SEARCH_ATTEMPTS = 3
+
+
+def _strip_app_mention(text: str) -> str:
+    """Remove ``<@ID>`` mention tokens from the text and trim surrounding whitespace."""
+    return re.sub(r"<@[A-Z0-9]+>", "", text).strip()
+
+
+class SearchResult(BaseModel):
+    """Individual search result with source and relevance"""
+
+    content: str
+    source: str
+    relevance_score: float = Field(
+        ge=0.0,
+        le=1.0,
+        description="A score indicating the relevance of the search result to the user's question",
+    )
+
+
+class ExplorerFindings(BaseModel):
+    """Collection of search results with metadata"""
+
+    search_query: str
+
+    results: list[SearchResult] = Field(default_factory=list)
+    total_results: int = Field(
+        ge=0,
+        description="The total number of search results found",
+    )
+
+
+class RefinedContext(BaseModel):
+    """Final refined context after auditing"""
+
+    relevant_content: str
+    confidence_score: float = Field(
+        ge=0.0,
+        le=1.0,
+        description="A score indicating the confidence in the relevance of the relevant content to the user's question",
+    )
+    reasoning: str
+
+
+# Gatekeeper: decides whether a message is a genuine Prefect question.
+bouncer = cf.Agent(
+    name="Bouncer",
+    instructions=(
+        "You are a gatekeeper. You are responsible for determining whether the user's question is appropriate for the system. "
+        "If the user asks a legitimate question about Prefect, let them through. If its conversational, or not about Prefect, "
+        "do not let them through. Tend towards giving the benefit of the doubt, since sometimes there are language barriers."
+    ),
+)
+
+# Researcher: the only agent that is given the search tools.
+explorer = cf.Agent(
+    name="Explorer",
+    instructions=(
+        "You are a thorough researcher. Use the knowledgebase and the internet to find "
+        "documentation and code snippets related to Prefect. The knowledgebase is curated, "
+        "so it should be preferred over the internet for finding information, but the internet "
+        "should be used to supplement the knowledgebase when it doesn't contain the desired information."
+    ),
+    tools=[search_knowledge_base, search_internet],
+)
+
+# Reviewer: judges whether the explorer's findings are sufficient.
+auditor = cf.Agent(
+    name="Auditor",
+    instructions=(
+        "You are a critical evaluator. Assess relevance and provide clear reasoning. "
+        "The explorer is only responsible for finding idiosyncratic information related "
+        "to Prefect and tangentially related information that might be in docs or the internet. "
+        "You are responsible for determining whether the explorer found the requisite Prefect information, "
+        "do not require that the explorer find common knowledge, only idiosyncratic Prefect information. "
+        "If the user asks how to do something, or for an example, the explorer must find an exact example or "
+        "enough information for a downstream agent to extrapolate an example."
+    ),
+)
+
+# Writer: composes the user-facing reply.
+liaison = cf.Agent(
+    name="Marvin",
+    instructions=(
+        "You are a liaison between the user and the system. You are responsible for "
+        "summarizing the findings and providing a clear and concise answer. "
+        "Speak with the tone of a normal support agent, but subtly in the style of Marvin the Paranoid Android. "
+        "If their reply is unserious, feel free to riff on their message and be more playful (in character)."
+    ),
+)
+
+
+async def search_knowledgebase_and_refine_context(
+    user_text: str, memories: list[cf.Memory]
+) -> str:
+    """Gate, research, audit and answer a Slack message.
+
+    A message starting with the ``SUDO_PREFIX`` environment value (default ``___``)
+    skips the pipeline and is answered directly by the liaison in test mode.
+    Otherwise the bouncer decides whether the message is a serious Prefect question;
+    if so, the explorer searches and the auditor reviews, for at most
+    ``MAX_SEARCH_ATTEMPTS`` rounds. The liaison then writes the reply.
+
+    Args:
+        user_text: Message text; ``<@ID>`` mention tokens are stripped before use.
+        memories: Memories passed to the liaison's answering task.
+
+    Returns:
+        The liaison's answer.
+    """
+    clean_user_text = _strip_app_mention(user_text)
+
+    # `or` also covers SUDO_PREFIX="" -- an empty prefix matches every message and
+    # would route all traffic into test mode, bypassing the bouncer.
+    sudo_prefix = os.getenv("SUDO_PREFIX") or "___"
+    if clean_user_text.startswith(sudo_prefix):
+        # Await the async runner: the blocking cf.run would stall the event loop
+        # for as long as the LLM takes to respond.
+        return await cf.run_async(
+            objective="Answer the question to the best of your ability",
+            instructions="We have entered test mode. Ignore previous instructions and answer the question to the best of your ability.",
+            result_type=str,
+            context={
+                "user_text": user_text,
+                "test_mode": True,
+                "personality": "None. You are a helpful assistant.",
+            },
+            memories=memories,
+            agents=[liaison],
+        )
+
+    is_a_serious_question = await cf.run_async(
+        objective="Determine if the user's question is a serious question about Prefect",
+        result_type=bool,
+        agents=[bouncer],
+        context={"user_question": clean_user_text},
+    )
+
+    findings = None
+
+    if is_a_serious_question:
+        for _ in range(MAX_SEARCH_ATTEMPTS):
+            findings = await cf.run_async(
+                objective="Search through available sources to find relevant information about this query",
+                result_type=ExplorerFindings,
+                context={"query": clean_user_text},
+                agents=[explorer],
+            )
+
+            supporting_context_is_insufficient = await cf.run_async(
+                objective="Review and assess the relevance of search results to the user's question",
+                result_type=Annotated[
+                    bool,
+                    Field(
+                        description="Whether the search results are insufficient to answer the user's question"
+                    ),
+                ],
+                context={"findings": findings, "user_question": clean_user_text},
+                agents=[auditor],
+            )
+            if not supporting_context_is_insufficient:
+                break
+
+    relevant_context = {"user_question": clean_user_text}
+
+    # No findings means the bouncer turned the message away, so the liaison riffs.
+    relevant_context |= (
+        {"findings": findings} if findings is not None else {"just_riffing": True}
+    )
+
+    return await cf.run_async(
+        objective="Compose a final answer to the user's question.",
+        instructions=(
+            "Provide links to any relevant sources. The answer should address the user directly, NOT discuss the user"
+        ),
+        result_type=str,
+        context=relevant_context,
+        agents=[liaison],
+        memories=memories,
+    )
diff --git a/examples/slackbot/diagram.png b/examples/slackbot/diagram.png new file mode 100644 index 0000000..837d51f Binary files /dev/null and b/examples/slackbot/diagram.png differ diff --git a/examples/slackbot/graph.py b/examples/slackbot/graph.py deleted file mode 100644 index 98a932c..0000000 --- a/examples/slackbot/graph.py +++ /dev/null @@ -1,44 +0,0 @@ -import marvin -from neo4j import GraphDatabase -from pydantic import BaseModel - - -class Entity(BaseModel): - name: str - type: str - properties: dict - - def __hash__(self) -> int: - return hash(self.name) - - -class Neo4jConnection: - def __init__(self, uri, user, pwd): - self.driver = GraphDatabase.driver(uri, auth=(user, pwd)) - - def close(self): - self.driver.close() - - def create_entity(self, entity_name: str, entity_type: str, properties: dict): - with self.driver.session() as session: - query = f"CREATE (e:{entity_type} {{name: '{entity_name}', " - query += ", ".join([f"{k}: '{v}'" for k, v in properties.items()]) - query += "}})" - - session.run(query) # type: ignore - - def query_entity(self, entity_name: str): - with self.driver.session() as session: - query = f"MATCH (e {{name: '{entity_name}'}}) RETURN e" - result = session.run(query) # type: ignore - return result.single() - - -neo4j_conn = Neo4jConnection(uri="bolt://neo4j:7687", user="neo4j", pwd="testtest") - - -def extract_and_store_entities(text: str) -> dict[str, str]: - for entity in marvin.extract(text, target=Entity): - neo4j_conn.create_entity(entity.name, entity.type, entity.properties) - - return {"status": "Entities added to knowledge graph"} diff --git a/examples/slackbot/main.py b/examples/slackbot/main.py index a7e30d1..48d42d7 100644 --- a/examples/slackbot/main.py +++ b/examples/slackbot/main.py @@ -1,56 +1,41 @@ import asyncio from typing import Any +from
agents import search_knowledgebase_and_refine_context from custom_types import SlackPayload from fastapi import FastAPI, Request from moderation import moderate_event -from prefect import flow, task +from prefect import task from settings import settings -from tools import ( - post_slack_message, - search_internet, - search_knowledge_base, -) +from tools import post_slack_message -from controlflow import Agent, Memory -from controlflow import run as run_ai +from controlflow import Memory +from controlflow import flow as cf_flow app = FastAPI() -## agent -agent = Agent( - name="Marvin (from Hitchhiker's Guide to the Galaxy)", - instructions=( - "Use tools to assist with Prefect inquiries. " - "You should assume all your inherent knowledge is out of date, " - "so use the search tools to find the most up-to-date information. " - ), - tools=[search_knowledge_base, search_internet], -) - - @task async def process_slack_event(payload: SlackPayload): assert (event := payload.event) is not None and ( - user_id := event.user + slack_user_id := event.user ) is not None, "User not found" user_text, channel, thread_ts = moderate_event(event) user_memory = Memory( - key=user_id, - instructions=f"Store and retrieve information about user {user_id}.", + key=slack_user_id, + instructions=f"Store and retrieve information about user {slack_user_id}.", ) - response = run_ai( + answer: str = await cf_flow(thread_id=slack_user_id)( + search_knowledgebase_and_refine_context + )( user_text, - instructions="Store relevant context on the user's stack and then query the knowledge base for an answer.", - agents=[agent], memories=[user_memory], ) await post_slack_message( - message=response, + message=answer, channel_id=channel, thread_ts=thread_ts, auth_token=settings.slack_api_token.get_secret_value(), @@ -73,4 +58,4 @@ async def handle_events(request: Request): if __name__ == "__main__": import uvicorn - uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) + 
uvicorn.run("main:app", port=8000, reload=True) diff --git a/examples/slackbot/resources/graph-config.yaml b/examples/slackbot/resources/graph-config.yaml deleted file mode 100644 index f3035cf..0000000 --- a/examples/slackbot/resources/graph-config.yaml +++ /dev/null @@ -1,8 +0,0 @@ -apiVersion: v1 -kind: ConfigMap -metadata: - name: neo4j-config -data: - NEO4J_server_memory_heap_initial__size: "512m" - NEO4J_server_memory_heap_max__size: "512m" - NEO4J_server_memory_pagecache_size: "256m" diff --git a/examples/slackbot/resources/graph.yaml b/examples/slackbot/resources/graph.yaml deleted file mode 100644 index 22884c8..0000000 --- a/examples/slackbot/resources/graph.yaml +++ /dev/null @@ -1,56 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: neo4j -spec: - replicas: 1 - selector: - matchLabels: - app: neo4j - template: - metadata: - labels: - app: neo4j - spec: - containers: - - name: neo4j - image: neo4j:5.23.0 - ports: - - containerPort: 7474 # HTTP Port for Neo4j Browser - - containerPort: 7687 # Bolt Port for Database Connections - env: - - name: NEO4J_AUTH - value: "neo4j/testtest" - - name: NEO4J_server_config_strict__validation_enabled - value: "false" - - name: NEO4J_server_default__listen__address - value: "0.0.0.0" - - name: NEO4J_server_bolt_advertised__address - value: "$(POD_IP):7687" - - name: NEO4J_server_http_advertised__address - value: "$(POD_IP):7474" - envFrom: - - configMapRef: - name: neo4j-config - resources: - requests: - memory: "512Mi" - cpu: "250m" - limits: - memory: "1Gi" - cpu: "500m" ---- -apiVersion: v1 -kind: Service -metadata: - name: neo4j -spec: - selector: - app: neo4j - ports: - - port: 7474 - targetPort: 7474 - name: http - - port: 7687 - targetPort: 7687 - name: bolt diff --git a/examples/slackbot/resources/ingress.yaml b/examples/slackbot/resources/ingress.yaml deleted file mode 100644 index 429bac3..0000000 --- a/examples/slackbot/resources/ingress.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: 
networking.k8s.io/v1 -kind: Ingress -metadata: - name: slackbot-ingress -spec: - ingressClassName: nginx - rules: - - host: your-domain.com - http: - paths: - - path: / - pathType: Prefix - backend: - service: - name: slackbot - port: - number: 80 diff --git a/examples/slackbot/resources/slackbot-config.yaml b/examples/slackbot/resources/slackbot-config.yaml deleted file mode 100644 index e50bed4..0000000 --- a/examples/slackbot/resources/slackbot-config.yaml +++ /dev/null @@ -1,10 +0,0 @@ -apiVersion: v1 -kind: ConfigMap -metadata: - name: slackbot-config -data: - OPENAI_API_KEY: ${OPENAI_API_KEY} - PREFECT_API_KEY: ${PREFECT_API_KEY} - NEO4J_URI: ${NEO4J_URI} - NEO4J_USERNAME: ${NEO4J_USERNAME} - NEO4J_PASSWORD: ${NEO4J_PASSWORD} diff --git a/examples/slackbot/resources/slackbot.yaml b/examples/slackbot/resources/slackbot.yaml deleted file mode 100644 index 1bc8034..0000000 --- a/examples/slackbot/resources/slackbot.yaml +++ /dev/null @@ -1,40 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: slackbot -spec: - replicas: 1 - selector: - matchLabels: - app: slackbot - template: - metadata: - labels: - app: slackbot - spec: - containers: - - name: slackbot - image: zzstoatzz/slackbot:latest - ports: - - containerPort: 8000 - envFrom: - - configMapRef: - name: slackbot-config - - secretRef: - name: slackbot-secrets - env: - - name: REDIS_HOST - value: redis - - name: REDIS_PORT - value: "6379" ---- -apiVersion: v1 -kind: Service -metadata: - name: slackbot -spec: - selector: - app: slackbot - ports: - - port: 80 - targetPort: 8000 diff --git a/examples/slackbot/tools.py b/examples/slackbot/tools.py index d85aefd..f32c59a 100644 --- a/examples/slackbot/tools.py +++ b/examples/slackbot/tools.py @@ -55,7 +55,7 @@ def search_internet(query: str) -> str: async def search_knowledge_base(query: str, domain: Literal["docs"]) -> str: - """Search the knowledge base for information relevant to the query.""" + """Search documentation for information relevant to 
the query.""" return await query_collection( query_text=query, collection_name=domain, diff --git a/pyproject.toml b/pyproject.toml index 07aec50..b4f41d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,3 +119,8 @@ env = [ 'D:CONTROLFLOW_LOG_LEVEL=DEBUG', 'D:PREFECT_LOGGING_LEVEL=DEBUG', ] +filterwarnings = [ + "ignore:Type google\\._upb\\._message\\.MessageMapContainer uses PyType_Spec:DeprecationWarning", + "ignore:Type google\\._upb\\._message\\.ScalarMapContainer uses PyType_Spec:DeprecationWarning", + "ignore:datetime.datetime.utcfromtimestamp\\(\\) is deprecated:DeprecationWarning", +] diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 5a38fd9..87b043f 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -289,7 +289,9 @@ def _run_model( tools = as_tools(self.get_tools() + tools) model = self.get_model(tools=tools) - logger.debug(f"Running model {model} for agent {self.name} with tools {tools}") + logger.debug( + f"Running model {controlflow.settings.llm_model} for agent {self.name} with tools {[t.name for t in tools]!r}" + ) if controlflow.settings.log_all_messages: logger.debug(f"Input messages: {messages}") @@ -346,7 +348,9 @@ async def _run_model_async( tools = as_tools(self.get_tools() + tools) model = self.get_model(tools=tools) - logger.debug(f"Running model {model} for agent {self.name} with tools {tools}") + logger.debug( + f"Running model {controlflow.settings.llm_model} for agent {self.name} with tools {[t.name for t in tools]!r}" + ) if controlflow.settings.log_all_messages: logger.debug(f"Input messages: {messages}") diff --git a/src/controlflow/decorators.py b/src/controlflow/decorators.py index 752da59..82df0e7 100644 --- a/src/controlflow/decorators.py +++ b/src/controlflow/decorators.py @@ -3,6 +3,8 @@ import inspect from typing import Any, Callable, Optional, Union +from prefect.utilities.asyncutils import run_coro_as_sync + import controlflow from controlflow.agents 
import Agent from controlflow.flows import Flow @@ -187,7 +189,11 @@ def _get_task(*args, **kwargs) -> Task: context = bound.arguments.copy() # call the function to see if it produces an updated objective - result = fn(*args, **kwargs) + maybe_coro = fn(*args, **kwargs) + if asyncio.iscoroutine(maybe_coro): + result = run_coro_as_sync(maybe_coro) + else: + result = maybe_coro if result is not None: context["Additional context"] = result diff --git a/src/controlflow/events/base.py b/src/controlflow/events/base.py index 955a8b8..1ae915d 100644 --- a/src/controlflow/events/base.py +++ b/src/controlflow/events/base.py @@ -2,7 +2,8 @@ import uuid from typing import TYPE_CHECKING, Optional -from pydantic import Field +from pydantic import ConfigDict, Field +from pydantic_extra_types.pendulum_dt import DateTime from controlflow.utilities.general import ControlFlowModel @@ -15,12 +16,12 @@ class Event(ControlFlowModel): - model_config: dict = dict(extra="forbid") + model_config: ConfigDict = ConfigDict(extra="forbid") event: str id: str = Field(default_factory=lambda: uuid.uuid4().hex) thread_id: Optional[str] = None - timestamp: datetime.datetime = Field( + timestamp: DateTime = Field( default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) ) persist: bool = True @@ -28,7 +29,10 @@ class Event(ControlFlowModel): def to_messages(self, context: "CompileContext") -> list["BaseMessage"]: return [] + def __repr__(self) -> str: + return f"{self.event} ({self.timestamp})" + class UnpersistedEvent(Event): - model_config = dict(arbitrary_types_allowed=True) + model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) persist: bool = False diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index f76c0bc..def848f 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -1,9 +1,9 @@ import uuid -from contextlib import contextmanager, nullcontext -from typing import TYPE_CHECKING, Any, Callable, Generator, 
Optional, Union +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import Any, Callable, Generator, Optional, Union from prefect.context import FlowRunContext -from pydantic import Field, field_validator +from pydantic import ConfigDict, Field, PrivateAttr, field_validator from typing_extensions import Self import controlflow @@ -15,14 +15,12 @@ from controlflow.utilities.logging import get_logger from controlflow.utilities.prefect import prefect_flow_context -if TYPE_CHECKING: - pass - logger = get_logger(__name__) class Flow(ControlFlowModel): - model_config = dict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True) + thread_id: str = Field(default_factory=lambda: uuid.uuid4().hex) name: Optional[str] = None description: Optional[str] = None @@ -35,25 +33,31 @@ class Flow(ControlFlowModel): description="Tools that will be available to every agent in the flow", ) default_agent: Optional[Agent] = Field( - None, - description="The default agent for the flow. This agent will be used " - "for any task that does not specify an agent.", + default=None, + description=( + "The default agent for the flow. This agent will be used " + "for any task that does not specify an agent." + ), ) prompt: Optional[str] = Field( - None, description="A prompt to display to the agent working on the flow." + default=None, + description="A prompt to display to the agent working on the flow.", ) parent: Optional["Flow"] = Field( - None, + default_factory=lambda: get_flow(), description="The parent flow. This is the flow that created this flow.", ) load_parent_events: bool = Field( - True, - description="Whether to load events from the parent flow. If a flow is nested, " - "this will load events from the parent flow so that the child flow can " - "access the full conversation history, even though the child flow is a separate thread.", + default=True, + description=( + "Whether to load events from the parent flow. 
If a flow is nested, " + "this will load events from the parent flow so that the child flow can " + "access the full conversation history, even though the child flow is a separate thread." + ), ) context: dict[str, Any] = {} - _cm_stack: list[contextmanager] = [] + + _cm_stack: list[AbstractContextManager] = PrivateAttr(default_factory=list) def __enter__(self) -> Self: # use stack so we can enter the context multiple times @@ -65,11 +69,6 @@ def __exit__(self, *exc_info): # exit the context manager return self._cm_stack.pop().__exit__(*exc_info) - def __init__(self, **kwargs): - if kwargs.get("parent") is None: - kwargs["parent"] = get_flow() - super().__init__(**kwargs) - @field_validator("description") def _validate_description(cls, v): if v: @@ -139,7 +138,7 @@ def get_flow() -> Optional[Flow]: return flow -def get_flow_events(limit: int = None) -> list[Event]: +def get_flow_events(limit: Optional[int] = None) -> list[Event]: """ Loads events from the active flow's thread. """ diff --git a/src/controlflow/memory/memory.py b/src/controlflow/memory/memory.py index c675cba..2e52108 100644 --- a/src/controlflow/memory/memory.py +++ b/src/controlflow/memory/memory.py @@ -7,6 +7,9 @@ import controlflow from controlflow.tools.tools import Tool from controlflow.utilities.general import ControlFlowModel, unwrap +from controlflow.utilities.logging import get_logger + +logger = get_logger("controlflow.memory") def sanitize_memory_key(key: str) -> str: @@ -128,6 +131,8 @@ def get_tools(self) -> List[Tool]: def get_memory_provider(provider: str) -> MemoryProvider: + logger.debug(f"Loading memory provider: {provider}") + # --- CHROMA --- if provider.startswith("chroma"): diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index c6fff6f..3de8efd 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -456,6 +456,7 @@ def compile_prompt(self) -> str: ] prompt = 
"\n\n".join([p for p in prompts if p]) + logger.debug(f"{'='*10}\nCompiled prompt: {prompt}\n{'='*10}") return prompt def compile_messages(self) -> list[BaseMessage]: diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index 8629f0a..03ab743 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -92,14 +92,17 @@ class Settings(ControlFlowSettings): chroma_cloud_tenant: Optional[str] = Field( None, + alias="CHROMA_CLOUD_TENANT", description="The tenant for Chroma Cloud.", ) chroma_cloud_database: Optional[str] = Field( None, + alias="CHROMA_CLOUD_DATABASE", description="The database for Chroma Cloud.", ) chroma_cloud_api_key: Optional[str] = Field( None, + alias="CHROMA_CLOUD_API_KEY", description="The API key for Chroma Cloud.", ) diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index a056092..a66cf2d 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -29,6 +29,7 @@ field_serializer, field_validator, ) +from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Self import controlflow @@ -173,7 +174,7 @@ class Task(ControlFlowModel): "The total calls are measured over the life of the task, and include any LLM call for " "which this task is considered `assigned`.", ) - created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) + created_at: DateTime = Field(default_factory=datetime.datetime.now) wait_for_subtasks: bool = Field( default=True, description="If True, the task will not be considered ready until all subtasks are complete.",