Skip to content

Commit

Permalink
chore: make posthog optional (use mock if not env not set)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nandan committed Feb 10, 2025
1 parent b0e1086 commit ebedbad
Show file tree
Hide file tree
Showing 16 changed files with 228 additions and 102 deletions.
12 changes: 9 additions & 3 deletions app/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from app.modules.parsing.graph_construction.parsing_controller import ParsingController
from app.modules.parsing.graph_construction.parsing_schema import ParsingRequest
from app.modules.utils.APIRouter import APIRouter
from app.core.dependencies import get_analytics_service, AnalyticsService

router = APIRouter()

Expand Down Expand Up @@ -55,6 +56,7 @@ async def get_api_key_user(
async def create_conversation(
conversation: SimpleConversationRequest,
db: Session = Depends(get_db),
analytics_service: AnalyticsService = Depends(get_analytics_service),
user=Depends(get_api_key_user),
):
user_id = user["user_id"]
Expand All @@ -67,17 +69,20 @@ async def create_conversation(
agent_ids=conversation.agent_ids,
)

controller = ConversationController(db, user_id, None)
controller = ConversationController(db, user_id, None, analytics_service)
return await controller.create_conversation(full_request)


@router.post("/parse")
async def parse_directory(
repo_details: ParsingRequest,
db: Session = Depends(get_db),
analytics_service: AnalyticsService = Depends(get_analytics_service),
user=Depends(get_api_key_user),
):
return await ParsingController.parse_directory(repo_details, db, user)
return await ParsingController.parse_directory(
repo_details, db, user, analytics_service
)


@router.get("/parsing-status/{project_id}")
Expand All @@ -94,13 +99,14 @@ async def post_message(
conversation_id: str,
message: MessageRequest,
db: Session = Depends(get_db),
analytics_service: AnalyticsService = Depends(get_analytics_service),
user=Depends(get_api_key_user),
):
if message.content == "" or message.content is None or message.content.isspace():
raise HTTPException(status_code=400, detail="Message content cannot be empty")

user_id = user["user_id"]
# Note: email is no longer available with API key auth
controller = ConversationController(db, user_id, None)
controller = ConversationController(db, user_id, None, analytics_service)
message_stream = controller.post_message(conversation_id, message, stream=False)
return StreamingResponse(message_stream, media_type="text/event-stream")
4 changes: 3 additions & 1 deletion app/celery/tasks/parsing_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from app.core.database import SessionLocal
from app.modules.parsing.graph_construction.parsing_schema import ParsingRequest
from app.modules.parsing.graph_construction.parsing_service import ParsingService
from app.core.dependencies import get_analytics_service

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,7 +43,8 @@ def process_parsing(
) -> None:
logger.info(f"Task received: Starting parsing process for project {project_id}")
try:
parsing_service = ParsingService(self.db, user_id)
analytics_service = get_analytics_service()
parsing_service = ParsingService(self.db, user_id, analytics_service)

async def run_parsing():
import time
Expand Down
42 changes: 42 additions & 0 deletions app/core/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from app.modules.utils.analytics_service import (
AnalyticsService,
PosthogAnalyticsService,
MockAnalyticsService,
)
from dotenv import load_dotenv
from starlette.datastructures import State
from fastapi import Request
import logging
import os

logger = logging.getLogger(__name__)


# Analytics service configuration


def init_analytics_service() -> AnalyticsService:
POSTHOG_API_KEY = os.getenv("POSTHOG_API_KEY") or ""
POSTHOG_HOST = os.getenv("POSTHOG_HOST") or ""

if POSTHOG_API_KEY != "" and POSTHOG_HOST != "":
logger.info(f"using PostHog analytics service with host {POSTHOG_HOST}")
return PosthogAnalyticsService(POSTHOG_API_KEY, POSTHOG_HOST)

logger.info(
"no AnalyticsService envs found (POSTHOG_API_KEY, POSTHOG_HOST), using MockAnalytics service"
)
return MockAnalyticsService()


def get_analytics_service(request: Request) -> AnalyticsService:
return request.app.state.analytics_service


# State initialization


def init_state(state: State):
load_dotenv(override=True)
state.analytics_service = init_analytics_service()
return state
11 changes: 10 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,21 @@
from app.modules.users.user_router import router as user_router
from app.modules.users.user_service import UserService
from app.modules.utils.firebase_setup import FirebaseSetup
from contextlib import asynccontextmanager
from app.core.dependencies import init_state, init_analytics_service

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


@asynccontextmanager
async def lifespan(app: FastAPI):
app.state = init_state(app.state)
yield
# service closing logic goes here


class MainApp:
def __init__(self):
load_dotenv(override=True)
Expand All @@ -49,7 +58,7 @@ def __init__(self):
)
exit(1)
self.setup_sentry()
self.app = FastAPI()
self.app = FastAPI(lifespan=lifespan)
self.setup_cors()
self.initialize_database()
self.check_and_set_env_vars()
Expand Down
12 changes: 8 additions & 4 deletions app/modules/auth/auth_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from app.modules.users.user_schema import CreateUser
from app.modules.users.user_service import UserService
from app.modules.utils.APIRouter import APIRouter
from app.modules.utils.posthog_helper import PostHogClient
from app.core.dependencies import get_analytics_service, AnalyticsService

SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL", None)

Expand All @@ -41,7 +41,11 @@ async def login(login_request: LoginRequest):
return JSONResponse(content={"error": f"ERROR: {str(e)}"}, status_code=400)

@auth_router.post("/signup")
async def signup(request: Request, db: Session = Depends(get_db)):
async def signup(
request: Request,
db: Session = Depends(get_db),
analytics_service: AnalyticsService = Depends(get_analytics_service),
):
body = json.loads(await request.body())
uid = body["uid"]
oauth_token = body["accessToken"]
Expand Down Expand Up @@ -73,8 +77,8 @@ async def signup(request: Request, db: Session = Depends(get_db)):
f"New signup: {body['email']} ({body['displayName']})"
)

PostHogClient().send_event(
uid,
analytics_service.capture_event(
f"{uid}",
"signup_event",
{
"email": body["email"],
Expand Down
14 changes: 12 additions & 2 deletions app/modules/conversations/conversation/conversation_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,23 @@
MessageResponse,
NodeContext,
)
from app.core.dependencies import AnalyticsService


class ConversationController:
def __init__(self, db: Session, user_id: str, user_email: str):
def __init__(
self,
db: Session,
user_id: str,
user_email: str,
analytics_service: AnalyticsService,
):
self.user_email = user_email
self.service = ConversationService.create(db, user_id, user_email)
self.service = ConversationService.create(
db, user_id, user_email, analytics_service
)
self.user_id = user_id
self.analytics_service = analytics_service

async def create_conversation(
self, conversation: CreateConversationRequest
Expand Down
56 changes: 40 additions & 16 deletions app/modules/conversations/conversation/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
from app.modules.projects.projects_service import ProjectService
from app.modules.users.user_service import UserService
from app.modules.utils.posthog_helper import PostHogClient
from app.core.dependencies import AnalyticsService

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(self, db, provider_service):
self.db = db
self.provider_service = provider_service
self.agent = None
self.current_agent_id = None
self.current_agent_id = None
self.classifier = None
self.agents_service = AgentsService(db)
self.agent_factory = AgentFactory(db, provider_service)
Expand Down Expand Up @@ -148,17 +148,24 @@ async def classifier_node(self, state: State) -> Command:
return Command(update={"response": "No query provided"}, goto=END)

agent_list = {agent.id: agent.status for agent in self.available_agents}

# First check - if this is a custom agent (non-SYSTEM), route directly
if state["agent_id"] in agent_list and agent_list[state["agent_id"]] != "SYSTEM":
if (
state["agent_id"] in agent_list
and agent_list[state["agent_id"]] != "SYSTEM"
):
# Initialize the agent if needed
if not self.agent or self.current_agent_id != state["agent_id"]:
try:
self.agent = self.agent_factory.get_agent(state["agent_id"], state["user_id"])
self.agent = self.agent_factory.get_agent(
state["agent_id"], state["user_id"]
)
self.current_agent_id = state["agent_id"]
except Exception as e:
logger.error(f"Failed to create agent {state['agent_id']}: {e}")
return Command(update={"response": "Failed to initialize agent"}, goto=END)
return Command(
update={"response": "Failed to initialize agent"}, goto=END
)
return Command(update={"agent_id": state["agent_id"]}, goto="agent_node")

# For system agents, perform classification
Expand All @@ -167,25 +174,33 @@ async def classifier_node(self, state: State) -> Command:
agent_id=state["agent_id"],
agent_descriptions=self.agent_descriptions,
)

response = await self.llm.ainvoke(prompt)
response = response.content.strip("`")
try:
agent_id, confidence = response.split("|")
confidence = float(confidence)
selected_agent_id = agent_id if confidence >= 0.5 and agent_id in agent_list else state["agent_id"]
selected_agent_id = (
agent_id
if confidence >= 0.5 and agent_id in agent_list
else state["agent_id"]
)
except (ValueError, TypeError):
logger.error("Classification format error, falling back to current agent")
selected_agent_id = state["agent_id"]

# Initialize the selected system agent
if not self.agent or self.current_agent_id != selected_agent_id:
try:
self.agent = self.agent_factory.get_agent(selected_agent_id, state["user_id"])
self.agent = self.agent_factory.get_agent(
selected_agent_id, state["user_id"]
)
self.current_agent_id = selected_agent_id
except Exception as e:
logger.error(f"Failed to create agent {selected_agent_id}: {e}")
return Command(update={"response": "Failed to initialize agent"}, goto=END)
return Command(
update={"response": "Failed to initialize agent"}, goto=END
)

logger.info(
f"Streaming AI response for conversation {state['conversation_id']} "
Expand All @@ -198,7 +213,7 @@ async def agent_node(self, state: State, writer: StreamWriter):
if not self.agent:
logger.error("Agent not initialized before agent_node execution")
return Command(update={"response": "Agent not initialized"}, goto=END)

try:
async for chunk in self.agent.run(
query=state["query"],
Expand Down Expand Up @@ -263,6 +278,7 @@ def __init__(
provider_service: ProviderService,
agent_injector_service: AgentInjectorService,
custom_agent_service: CustomAgentsService,
analytics_service: AnalyticsService,
):
self.sql_db = db
self.user_id = user_id
Expand All @@ -272,9 +288,16 @@ def __init__(
self.provider_service = provider_service
self.agent_injector_service = agent_injector_service
self.custom_agent_service = custom_agent_service
self.analytics_service = analytics_service

@classmethod
def create(cls, db: Session, user_id: str, user_email: str):
def create(
cls,
db: Session,
user_id: str,
user_email: str,
analytics_service: AnalyticsService,
):
project_service = ProjectService(db)
history_manager = ChatHistoryService(db)
provider_service = ProviderService(db, user_id)
Expand All @@ -289,6 +312,7 @@ def create(cls, db: Session, user_id: str, user_email: str):
provider_service,
agent_injector_service,
custom_agent_service,
analytics_service,
)

async def check_conversation_access(
Expand Down Expand Up @@ -395,7 +419,7 @@ def _create_conversation_record(
f"Project id : {conversation.project_ids[0]} Created new conversation with ID: {conversation_id}, title: {title}, user_id: {user_id}, agent_id: {conversation.agent_ids[0]}"
)
provider_name = self.provider_service.get_llm_provider_name()
PostHogClient().send_event(
self.analytics_service.capture_event(
user_id,
"create Conversation Event",
{
Expand Down Expand Up @@ -452,7 +476,7 @@ async def store_message(
logger.info(f"Stored message in conversation {conversation_id}")
provider_name = self.provider_service.get_llm_provider_name()

PostHogClient().send_event(
self.analytics_service.capture_event(
user_id,
"message post event",
{"conversation_id": conversation_id, "llm": provider_name},
Expand Down Expand Up @@ -583,7 +607,7 @@ async def regenerate_last_message(
await self._archive_subsequent_messages(
conversation_id, last_human_message.created_at
)
PostHogClient().send_event(
self.analytics_service.capture_event(
user_id,
"regenerate_conversation_event",
{"conversation_id": conversation_id},
Expand Down Expand Up @@ -740,7 +764,7 @@ async def delete_conversation(self, conversation_id: str, user_id: str) -> dict:
# If we get here, commit the transaction
self.sql_db.commit()

PostHogClient().send_event(
self.analytics_service.capture_event(
user_id,
"delete_conversation_event",
{"conversation_id": conversation_id},
Expand Down
Loading

0 comments on commit ebedbad

Please sign in to comment.