From 314b30766ea9b0b57b8164466b25e4f9ba6f4f9b Mon Sep 17 00:00:00 2001 From: Harvey Date: Wed, 14 Jan 2026 16:03:11 +0800 Subject: [PATCH 01/11] feat: Add pagination support to consume records retrieval --- service/app/api/v1/redemption.py | 20 ++++++++--- service/app/repos/consume.py | 5 +-- web/src/components/admin/TrendChartTab.tsx | 2 +- web/src/components/admin/UserRankingsTab.tsx | 2 +- web/src/components/layouts/XyzenAgent.tsx | 1 - web/src/service/redemptionService.ts | 38 ++++++++++++++++++++ 6 files changed, 58 insertions(+), 10 deletions(-) diff --git a/service/app/api/v1/redemption.py b/service/app/api/v1/redemption.py index 33a8bace..0c465fd4 100644 --- a/service/app/api/v1/redemption.py +++ b/service/app/api/v1/redemption.py @@ -672,6 +672,7 @@ async def get_consume_records( end_date: Optional[str] = None, tz: Optional[str] = None, limit: int = 10000, + offset: int = 0, db: AsyncSession = Depends(get_db_session), ): """ @@ -690,7 +691,9 @@ async def get_consume_records( Returns: List of consume records """ - logger.info(f"Admin fetching consume records from {start_date} to {end_date}, tz: {tz}, limit: {limit}") + logger.info( + f"Admin fetching consume records from {start_date} to {end_date}, tz: {tz}, limit: {limit}, offset: {offset}" + ) # Verify admin secret if admin_secret != configs.Admin.secret: @@ -701,12 +704,19 @@ async def get_consume_records( ) try: + if offset < 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="offset must be >= 0", + ) + consume_repo = ConsumeRepository(db) records = await consume_repo.list_all_consume_records( - start_date, - end_date, - tz, - limit, + start_date=start_date, + end_date=end_date, + tz=tz, + limit=limit, + offset=offset, ) return [ diff --git a/service/app/repos/consume.py b/service/app/repos/consume.py index f94112ef..48543c2a 100644 --- a/service/app/repos/consume.py +++ b/service/app/repos/consume.py @@ -429,6 +429,7 @@ async def list_all_consume_records( end_date: str | None = None, tz: str | None = None, limit: int = 10000, + offset: int = 0, ) -> list[ConsumeRecord]: """ Get all consumption records with optional date filtering. 
@@ -444,7 +445,7 @@ async def list_all_consume_records( """ from datetime import datetime, timezone - logger.debug(f"Fetching consume records from {start_date} to {end_date}, limit: {limit}") + logger.debug(f"Fetching consume records from {start_date} to {end_date}, limit: {limit}, offset: {offset}") query = select(ConsumeRecord) @@ -469,7 +470,7 @@ async def list_all_consume_records( query = query.where(ConsumeRecord.created_at <= end_dt) # Order by creation time ascending for chronological trend analysis - query = query.order_by(ConsumeRecord.created_at.asc()).limit(limit) # type: ignore + query = query.order_by(ConsumeRecord.created_at.asc()).offset(offset).limit(limit) # type: ignore result = await self.db.exec(query) records = list(result.all()) diff --git a/web/src/components/admin/TrendChartTab.tsx b/web/src/components/admin/TrendChartTab.tsx index 74844fbe..08c617d6 100644 --- a/web/src/components/admin/TrendChartTab.tsx +++ b/web/src/components/admin/TrendChartTab.tsx @@ -34,7 +34,7 @@ export function TrendChartTab({ adminSecret }: TrendChartTabProps) { const startDate = formatInTimeZone(start, DEFAULT_TIMEZONE, "yyyy-MM-dd"); const endDate = formatInTimeZone(end, DEFAULT_TIMEZONE, "yyyy-MM-dd"); - const data = await redemptionService.getConsumeRecords( + const data = await redemptionService.getAllConsumeRecords( adminSecret, startDate, endDate, diff --git a/web/src/components/admin/UserRankingsTab.tsx b/web/src/components/admin/UserRankingsTab.tsx index 2e402725..0338ad09 100644 --- a/web/src/components/admin/UserRankingsTab.tsx +++ b/web/src/components/admin/UserRankingsTab.tsx @@ -72,7 +72,7 @@ export function UserRankingsTab({ adminSecret }: UserRankingsTabProps) { "Fetching consume records with adminSecret:", adminSecret ? "***" : "missing", ); - const data = await redemptionService.getConsumeRecords( + const data = await redemptionService.getAllConsumeRecords( adminSecret, startDate, endDate, diff --git a/web/src/components/layouts/XyzenAgent.tsx b/web/src/components/layouts/XyzenAgent.tsx index c9bfc944..8188665c 100644 --- a/web/src/components/layouts/XyzenAgent.tsx +++ b/web/src/components/layouts/XyzenAgent.tsx @@ -424,7 +424,6 @@ export default function XyzenAgent({ const allAgents = [...filteredSystemAgents, ...regularAgents]; // Clean sidebar with auto-loaded MCPs for system agents - return ( { const url = new URL( `${this.getBackendUrl()}/xyzen/api/v1/redemption/admin/stats/consume-records`, @@ -247,6 +248,7 @@ class RedemptionService { url.searchParams.append("tz", tz); } url.searchParams.append("limit", limit.toString()); + url.searchParams.append("offset", offset.toString()); const response = await fetch(url.toString(), { method: "GET", @@ -266,6 +268,42 @@ class RedemptionService { return response.json(); } + /** + * Get all consume records by auto-pagination (admin only) + * + * NOTE: Use with a date range to avoid downloading excessive data. 
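+   *
+   * Minimal usage sketch (the date range and timezone below are illustrative
+   * placeholders, not values taken from this repo):
+   *
+   *   const records = await redemptionService.getAllConsumeRecords(
+   *     adminSecret,       // admin secret string
+   *     "2026-01-01",      // start date (yyyy-MM-dd)
+   *     "2026-01-14",      // end date (yyyy-MM-dd)
+   *     "Asia/Shanghai",   // IANA timezone name
+   *   );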
+ */ + async getAllConsumeRecords( + adminSecret: string, + startDate?: string, + endDate?: string, + tz?: string, + pageSize = 10000, + ): Promise { + const all: ConsumeRecordResponse[] = []; + let offset = 0; + + while (true) { + const page = await this.getConsumeRecords( + adminSecret, + startDate, + endDate, + tz, + pageSize, + offset, + ); + all.push(...page); + + if (page.length < pageSize) { + break; + } + + offset += pageSize; + } + + return all; + } + /** * Get user activity statistics (admin only) */ From 3e7582ecea306f3609d224c5b7562699f763749d Mon Sep 17 00:00:00 2001 From: Harvey Date: Wed, 14 Jan 2026 18:55:49 +0800 Subject: [PATCH 02/11] feat: Enhance CI/CD workflow and introduce Spatial Agent Workspace - Added beta tagging for Docker images in GitHub Actions workflow. - Updated deployment commands to use beta tags for service and web images. - Introduced a new Product Interaction Design document for the Spatial Agent Workspace. - Integrated SpatialWorkspace component into the AppFullscreen layout. - Modified ActivityBar to include a new panel for the workspace concept. - Enhanced XyzenAgent component with tooltips for marketplace published agents. - Removed unused tooltip component and cleaned up related code. - Updated localization files to include new strings for agent marketplace features. - Refactored UI slice to accommodate new workspace panel type. --- .github/workflows/prod-build.yaml | 24 +- .github/workflows/test-build.yaml | 21 +- PRODUCT_INTERACTION.md | 59 +++ web/src/app/AppFullscreen.tsx | 7 + .../animate-ui/components/radix/tooltip.tsx | 83 ---- web/src/components/layouts/ActivityBar.tsx | 13 +- web/src/components/layouts/XyzenAgent.tsx | 232 +++++++---- web/src/components/layouts/XyzenChat.tsx | 13 - web/src/components/layouts/XyzenTopics.tsx | 377 ------------------ web/src/i18n/locales/en/agents.json | 5 + web/src/i18n/locales/ja/agents.json | 5 + web/src/i18n/locales/zh/agents.json | 5 + web/src/store/slices/uiSlice/index.ts | 6 +- 13 files changed, 283 insertions(+), 567 deletions(-) create mode 100644 PRODUCT_INTERACTION.md delete mode 100644 web/src/components/animate-ui/components/radix/tooltip.tsx delete mode 100644 web/src/components/layouts/XyzenTopics.tsx diff --git a/.github/workflows/prod-build.yaml b/.github/workflows/prod-build.yaml index fabcffdd..b842a5ce 100644 --- a/.github/workflows/prod-build.yaml +++ b/.github/workflows/prod-build.yaml @@ -16,6 +16,7 @@ jobs: id: build_setup run: | echo "build_start=$(date '+%Y-%m-%d %H:%M:%S')" >> $GITHUB_OUTPUT + echo "release_tag=RELEASE.$(date -u '+%Y-%m-%dT%H-%M-%SZ')" >> $GITHUB_OUTPUT echo "commit_author=$(git log -1 --pretty=format:'%an')" >> $GITHUB_OUTPUT echo "commit_email=$(git log -1 --pretty=format:'%ae')" >> $GITHUB_OUTPUT echo "commit_message=$(git log -1 --pretty=format:'%s')" >> $GITHUB_OUTPUT @@ -35,11 +36,17 @@ jobs: - name: Build and push Service Docker image run: | - docker build service -t registry.sciol.ac.cn/sciol/xyzen-service:latest --push + docker buildx build service \ + -t registry.sciol.ac.cn/sciol/xyzen-service:latest \ + -t registry.sciol.ac.cn/sciol/xyzen-service:${{ steps.build_setup.outputs.release_tag }} \ + --push - name: Build and push Web Docker image run: | - docker build web -t registry.sciol.ac.cn/sciol/xyzen-web:latest --push + docker buildx build web \ + -t registry.sciol.ac.cn/sciol/xyzen-web:latest \ + -t registry.sciol.ac.cn/sciol/xyzen-web:${{ steps.build_setup.outputs.release_tag }} \ + --push - name: Set up kubeconfig env: @@ -51,11 +58,10 @@ jobs: - 
name: Rolling update deployments run: | - kubectl rollout restart deployment xyzen -n bohrium - kubectl rollout restart deployment xyzen-celery -n bohrium - kubectl rollout restart deployment xyzen -n sciol - kubectl rollout restart deployment xyzen-web -n sciol - + kubectl -n bohrium set image deployment/xyzen *=registry.sciol.ac.cn/sciol/xyzen-service:${{ steps.build_setup.outputs.release_tag }} + kubectl -n bohrium set image deployment/xyzen-celery *=registry.sciol.ac.cn/sciol/xyzen-service:${{ steps.build_setup.outputs.release_tag }} + kubectl -n sciol set image deployment/xyzen *=registry.sciol.ac.cn/sciol/xyzen-service:${{ steps.build_setup.outputs.release_tag }} + kubectl -n sciol set image deployment/xyzen-web *=registry.sciol.ac.cn/sciol/xyzen-web:${{ steps.build_setup.outputs.release_tag }} - name: Calculate build duration if: always() @@ -115,5 +121,5 @@ jobs: commit_sha: ${{ steps.build_setup.outputs.commit_sha }} commit_sha_short: ${{ steps.build_setup.outputs.commit_sha_short }} commit_date: ${{ steps.build_setup.outputs.commit_date }} - service_image: 'registry.sciol.ac.cn/sciol/xyzen-service:latest' - web_image: 'registry.sciol.ac.cn/sciol/xyzen-web:latest' + service_image: 'registry.sciol.ac.cn/sciol/xyzen-service:${{ steps.build_setup.outputs.release_tag }}' + web_image: 'registry.sciol.ac.cn/sciol/xyzen-web:${{ steps.build_setup.outputs.release_tag }}' diff --git a/.github/workflows/test-build.yaml b/.github/workflows/test-build.yaml index af06972b..085542fb 100644 --- a/.github/workflows/test-build.yaml +++ b/.github/workflows/test-build.yaml @@ -16,6 +16,7 @@ jobs: id: build_setup run: | echo "build_start=$(date '+%Y-%m-%d %H:%M:%S')" >> $GITHUB_OUTPUT + echo "beta_tag=BETA.$(date -u '+%Y-%m-%dT%H-%M-%SZ')" >> $GITHUB_OUTPUT echo "commit_author=$(git log -1 --pretty=format:'%an')" >> $GITHUB_OUTPUT echo "commit_email=$(git log -1 --pretty=format:'%ae')" >> $GITHUB_OUTPUT echo "commit_message=$(git log -1 --pretty=format:'%s')" >> $GITHUB_OUTPUT @@ -35,11 +36,17 @@ jobs: - name: Build and push Service Docker image run: | - docker build service -t registry.sciol.ac.cn/sciol/xyzen-service:test --push + docker buildx build service \ + -t registry.sciol.ac.cn/sciol/xyzen-service:test \ + -t registry.sciol.ac.cn/sciol/xyzen-service:${{ steps.build_setup.outputs.beta_tag }} \ + --push - name: Build and push Web Docker image run: | - docker build web -t registry.sciol.ac.cn/sciol/xyzen-web:test --push + docker buildx build web \ + -t registry.sciol.ac.cn/sciol/xyzen-web:test \ + -t registry.sciol.ac.cn/sciol/xyzen-web:${{ steps.build_setup.outputs.beta_tag }} \ + --push - name: Download Let's Encrypt CA run: curl -o ca.crt https://letsencrypt.org/certs/isrgrootx1.pem @@ -50,19 +57,19 @@ jobs: --server=${{ secrets.SCIENCEOL_K8S_SERVER_URL }} \ --token=${{ secrets.SCIENCEOL_K8S_ADMIN_TOKEN }} \ --certificate-authority=ca.crt \ - rollout restart deployment xyzen -n bohrium + set image deployment/xyzen -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ steps.build_setup.outputs.beta_tag }} kubectl \ --server=${{ secrets.SCIENCEOL_K8S_SERVER_URL }} \ --token=${{ secrets.SCIENCEOL_K8S_ADMIN_TOKEN }} \ --certificate-authority=ca.crt \ - rollout restart deployment xyzen-web -n bohrium + set image deployment/xyzen-web -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-web:${{ steps.build_setup.outputs.beta_tag }} kubectl \ --server=${{ secrets.SCIENCEOL_K8S_SERVER_URL }} \ --token=${{ secrets.SCIENCEOL_K8S_ADMIN_TOKEN }} \ --certificate-authority=ca.crt \ - rollout restart deployment 
xyzen-celery -n bohrium + set image deployment/xyzen-celery -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ steps.build_setup.outputs.beta_tag }} - name: Calculate build duration if: always() @@ -122,5 +129,5 @@ jobs: commit_sha: ${{ steps.build_setup.outputs.commit_sha }} commit_sha_short: ${{ steps.build_setup.outputs.commit_sha_short }} commit_date: ${{ steps.build_setup.outputs.commit_date }} - service_image: 'registry.sciol.ac.cn/sciol/xyzen-service:test' - web_image: 'registry.sciol.ac.cn/sciol/xyzen-web:test' + service_image: 'registry.sciol.ac.cn/sciol/xyzen-service:${{ steps.build_setup.outputs.beta_tag }}' + web_image: 'registry.sciol.ac.cn/sciol/xyzen-web:${{ steps.build_setup.outputs.beta_tag }}' diff --git a/PRODUCT_INTERACTION.md b/PRODUCT_INTERACTION.md new file mode 100644 index 00000000..fd1ad57e --- /dev/null +++ b/PRODUCT_INTERACTION.md @@ -0,0 +1,59 @@ +# Product Interaction Design: Spatial Agent Workspace + +## Vision + +To transform the user experience from managing a "list of tools" to entering a "digital workspace" where Agents are active, specialized assets. The interface balances high-density information with immersive focus states. + +## Core Concepts + +### 1. The Canvas (The Team View) + +Instead of a flat list, Agents exist on an infinite 2D canvas. + +- **Visuals**: Agents are represented as "Nodes" or "Bases" rather than simple list items. +- **States**: + - **Idle**: Subtle pulsing or static. + - **Working**: Glowing, animated data streams. + - **Collaborating**: Connecting lines between agents. +- **Interaction**: Pan and zoom to explore the team. Drag agents to cluster them by function (e.g., "Creative Team", "Dev Ops"). + +### 2. The Focus (The Deep Dive) + +Transitions should be seamless, maintaining context while focusing on the task. + +- **Action**: Clicking an Agent transitions from "Map View" to "Focus View". +- **Animation**: The camera smoothly zooms in to the specific Agent node. The background (other agents) blurs but remains visible in the periphery, providing a sense of "location" within the system. +- **Feedback**: The "Chat Window" isn't a separate page; it slides out from the Agent node itself, reinforcing that you are talking _to_ that specific entity. + +### 3. The Workspace (The Chat Interface) + +The chat interface is the primary daily driver, so it expands to occupy valuable screen real estate while keeping asset context available. + +- **Layout**: + - **Left/Center (Chat)**: Wide, comfortable reading area. The primary focus. + - **Right (Context/Assets)**: Collapsible sidebar showing the Agent's specific "Memory", "Tools", and "Files". +- **Switching**: + - **Fast Switch**: A "Dock" or "Mini-map" allows jumping between recently used agents without zooming all the way out. + - **Zoom Out**: A gesture or button seamlessly pulls the camera back to the Canvas view to see the whole team. + +## User Journey: "Hiring to Commanding" + +1. **Enter Workspace**: User lands on the Canvas. See 5 Agents scattered. "Market Analysis" agent is glowing red (busy). +2. **Select**: User clicks "Market Analysis". +3. **Transition**: Screen zooms in. Background blurs. Chat window slides in from the right, occupying 70% of the screen. +4. **Engage**: User chats. Uploads a PDF. +5. **Multitask**: User needs "Copywriter". + - _Option A_: Zoom out (Esc), find Copywriter, Zoom in. + - _Option B (Fast)_: Click "Copywriter" from the Quick Dock. Camera pans laterally to the Copywriter node. +6. 
**Collaborate**: User drags a connecting line from "Market Analysis" output to "Copywriter". + +## Technical Prototype + +The accompanying `SpatialWorkspace` component demonstrates: + +- **Spatial Layout**: Absolute positioning on a scalable surface. +- **Camera Logic**: Calculating translation and scale to center a target element. +- **Immersive Transition**: CSS/Motion transitions for smooth zooming. +- **Contextual Chat**: Sidebar entry upon focus. + +This design elevates the Agent from a "row in a database" to a "teammate at a desk". diff --git a/web/src/app/AppFullscreen.tsx b/web/src/app/AppFullscreen.tsx index 67a9ac0f..1a57f569 100644 --- a/web/src/app/AppFullscreen.tsx +++ b/web/src/app/AppFullscreen.tsx @@ -6,6 +6,7 @@ import { useEffect, useState } from "react"; import { createPortal } from "react-dom"; import AgentMarketplace from "@/app/marketplace/AgentMarketplace"; +import { SpatialWorkspace } from "@/app/tmp/SpatialWorkspace"; import { ActivityBar } from "@/components/layouts/ActivityBar"; import { AppHeader } from "@/components/layouts/AppHeader"; import KnowledgeBase from "@/components/layouts/KnowledgeBase"; @@ -120,6 +121,12 @@ export function AppFullscreen({ )} + + {activePanel === "workspace-test" && ( +
+          <SpatialWorkspace />
+        </div>
+ )} diff --git a/web/src/components/animate-ui/components/radix/tooltip.tsx b/web/src/components/animate-ui/components/radix/tooltip.tsx deleted file mode 100644 index 8ffd8b75..00000000 --- a/web/src/components/animate-ui/components/radix/tooltip.tsx +++ /dev/null @@ -1,83 +0,0 @@ -import { - TooltipArrow as TooltipArrowPrimitive, - TooltipContent as TooltipContentPrimitive, - TooltipPortal as TooltipPortalPrimitive, - Tooltip as TooltipPrimitive, - TooltipProvider as TooltipProviderPrimitive, - TooltipTrigger as TooltipTriggerPrimitive, - type TooltipContentProps as TooltipContentPrimitiveProps, - type TooltipProps as TooltipPrimitiveProps, - type TooltipProviderProps as TooltipProviderPrimitiveProps, - type TooltipTriggerProps as TooltipTriggerPrimitiveProps, -} from "@/components/animate-ui/primitives/radix/tooltip"; -import { cn } from "@/lib/utils"; - -type TooltipProviderProps = TooltipProviderPrimitiveProps; - -function TooltipProvider({ - delayDuration = 0, - ...props -}: TooltipProviderProps) { - return ; -} - -type TooltipProps = TooltipPrimitiveProps & { - delayDuration?: TooltipPrimitiveProps["delayDuration"]; - title?: React.ReactNode; -}; - -function Tooltip({ - delayDuration = 0, - title, - children, - ...props -}: TooltipProps) { - return ( - - - {children} - {title} - - - ); -} - -type TooltipTriggerProps = TooltipTriggerPrimitiveProps; - -function TooltipTrigger({ ...props }: TooltipTriggerProps) { - return ; -} - -type TooltipContentProps = TooltipContentPrimitiveProps; - -function TooltipContent({ - className, - sideOffset, - children, - ...props -}: TooltipContentProps) { - return ( - - - {children} - - - - ); -} - -export { - Tooltip, - TooltipContent, - TooltipTrigger, - type TooltipContentProps, - type TooltipProps, - type TooltipTriggerProps, -}; diff --git a/web/src/components/layouts/ActivityBar.tsx b/web/src/components/layouts/ActivityBar.tsx index 7f1e39a2..affb1d71 100644 --- a/web/src/components/layouts/ActivityBar.tsx +++ b/web/src/components/layouts/ActivityBar.tsx @@ -2,11 +2,16 @@ import { ChatBubbleLeftRightIcon, FolderIcon, SparklesIcon, + Squares2X2Icon, } from "@heroicons/react/24/outline"; import { motion } from "framer-motion"; import { useTranslation } from "react-i18next"; -export type ActivityPanel = "chat" | "knowledge" | "marketplace"; +export type ActivityPanel = + | "chat" + | "knowledge" + | "marketplace" + | "workspace-test"; interface ActivityBarProps { activePanel: ActivityPanel; @@ -116,6 +121,12 @@ export const ActivityBar: React.FC = ({ label: t("app.activityBar.community"), disabled: false, }, + { + panel: "workspace-test" as ActivityPanel, + icon: Squares2X2Icon, + label: "Workspace Concept", + disabled: false, + }, ]; return ( diff --git a/web/src/components/layouts/XyzenAgent.tsx b/web/src/components/layouts/XyzenAgent.tsx index 8188665c..7bba3aea 100644 --- a/web/src/components/layouts/XyzenAgent.tsx +++ b/web/src/components/layouts/XyzenAgent.tsx @@ -1,10 +1,19 @@ "use client"; -import McpIcon from "@/assets/McpIcon"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "@/components/animate-ui/components/animate/tooltip"; import { Badge } from "@/components/base/Badge"; import { useAuth } from "@/hooks/useAuth"; -import { PencilIcon, TrashIcon } from "@heroicons/react/24/outline"; +import { + PencilIcon, + ShoppingBagIcon, + TrashIcon, +} from "@heroicons/react/24/outline"; import { motion, type Variants } from "framer-motion"; -import React, { useEffect, useRef, useState } from "react"; +import 
React, { useEffect, useMemo, useRef, useState } from "react"; import { useTranslation } from "react-i18next"; import AddAgentModal from "@/components/modals/AddAgentModal"; @@ -19,6 +28,7 @@ import type { Agent } from "@/types/agents"; interface AgentCardProps { agent: Agent; + isMarketplacePublished?: boolean; onClick?: (agent: Agent) => void; onEdit?: (agent: Agent) => void; onDelete?: (agent: Agent) => void; @@ -46,6 +56,7 @@ interface ContextMenuProps { onDelete: () => void; onClose: () => void; isDefaultAgent?: boolean; + isMarketplacePublished?: boolean; agent?: Agent; } @@ -56,6 +67,7 @@ const ContextMenu: React.FC = ({ onDelete, onClose, isDefaultAgent = false, + isMarketplacePublished = false, }) => { const { t } = useTranslation(); const menuRef = useRef(null); @@ -104,16 +116,42 @@ const ContextMenu: React.FC = ({ {t("agents.editAgent")} - + {isMarketplacePublished ? ( + + + + + + + + {t("agents.deleteBlockedMessage", { + defaultValue: + "This agent is published to Agent Market. Please unpublish it first, then delete it.", + })} + + + ) : ( + + )}
); }; @@ -121,10 +159,12 @@ const ContextMenu: React.FC = ({ // 详细版本-包括名字,描述,头像,标签以及GPT模型 const AgentCard: React.FC = ({ agent, + isMarketplacePublished = false, onClick, onEdit, onDelete, }) => { + const { t } = useTranslation(); const [contextMenu, setContextMenu] = useState<{ x: number; y: number; @@ -139,7 +179,6 @@ const AgentCard: React.FC = ({ const { clientX, clientY } = touch; longPressTimer.current = setTimeout(() => { - isLongPress.current = true; setContextMenu({ x: clientX, y: clientY }); // Haptic feedback (best-effort) try { @@ -237,15 +276,33 @@ const AgentCard: React.FC = ({ {agent.name} - {/* MCP servers badge */} - {agent.mcp_servers && agent.mcp_servers.length > 0 && ( - - - {agent.mcp_servers.length} - + {/* Marketplace published badge */} + {isMarketplacePublished && ( + + + + + + + + + + + + {t("agents.badges.marketplace", { + defaultValue: "Published to Marketplace", + })} + + )} {/* Knowledge set badge */} @@ -276,6 +333,7 @@ const AgentCard: React.FC = ({ onDelete={() => onDelete?.(agent)} onClose={() => setContextMenu(null)} isDefaultAgent={isDefaultAgent} + isMarketplacePublished={isMarketplacePublished} agent={agent} /> )} @@ -329,6 +387,14 @@ export default function XyzenAgent({ // Fetch marketplace listings to check if deleted agent has a published version const { data: myListings } = useMyMarketplaceListings(); + const publishedAgentIds = useMemo(() => { + const ids = new Set(); + for (const listing of myListings ?? []) { + if (listing.is_published) ids.add(listing.agent_id); + } + return ids; + }, [myListings]); + useEffect(() => { fetchAgents(); }, [fetchAgents]); @@ -425,60 +491,74 @@ export default function XyzenAgent({ // Clean sidebar with auto-loaded MCPs for system agents return ( - - {allAgents.map((agent) => ( - - ))} - - setAddModalOpen(false)} - /> - setEditModalOpen(false)} - agent={editingAgent} - /> - {agentToDelete && ( - setConfirmModalOpen(false)} - onConfirm={() => { - deleteAgent(agentToDelete.id); - setConfirmModalOpen(false); - setAgentToDelete(null); - }} - title={t("agents.deleteTitle")} - message={(() => { - const hasListing = myListings?.some( - (l) => l.agent_id === agentToDelete.id, - ); - if (hasListing) { - return t("agents.deleteConfirmWithListing", { - name: agentToDelete.name, - }); - } - return t("agents.deleteConfirm", { name: agentToDelete.name }); - })()} + {allAgents.map((agent) => ( + + ))} + + setAddModalOpen(false)} /> - )} - + setEditModalOpen(false)} + agent={editingAgent} + /> + {agentToDelete && ( + setConfirmModalOpen(false)} + onConfirm={() => { + if (publishedAgentIds.has(agentToDelete.id)) return; + deleteAgent(agentToDelete.id); + setConfirmModalOpen(false); + setAgentToDelete(null); + }} + title={ + publishedAgentIds.has(agentToDelete.id) + ? t("agents.deleteBlockedTitle", { + defaultValue: "Can't delete agent", + }) + : t("agents.deleteTitle") + } + message={ + publishedAgentIds.has(agentToDelete.id) + ? t("agents.deleteBlockedMessage", { + defaultValue: + "This agent is published to Agent Market. Please unpublish it first, then delete it.", + }) + : t("agents.deleteConfirm", { name: agentToDelete.name }) + } + confirmLabel={ + publishedAgentIds.has(agentToDelete.id) + ? 
t("common.ok") + : t("agents.deleteAgent") + } + cancelLabel={t("common.cancel")} + destructive={!publishedAgentIds.has(agentToDelete.id)} + /> + )} + + ); } diff --git a/web/src/components/layouts/XyzenChat.tsx b/web/src/components/layouts/XyzenChat.tsx index b2b9d9e7..a79219fa 100644 --- a/web/src/components/layouts/XyzenChat.tsx +++ b/web/src/components/layouts/XyzenChat.tsx @@ -115,19 +115,6 @@ function BaseChat({ config, historyEnabled = false }: BaseChatProps) { return (
- - {/* Add toolbar even in empty state for history access */} -
-
- -
); } diff --git a/web/src/components/layouts/XyzenTopics.tsx b/web/src/components/layouts/XyzenTopics.tsx deleted file mode 100644 index 0ad82052..00000000 --- a/web/src/components/layouts/XyzenTopics.tsx +++ /dev/null @@ -1,377 +0,0 @@ -"use client"; - -import EditableTitle from "@/components/base/EditableTitle"; -import { LoadingSpinner } from "@/components/base/LoadingSpinner"; -import ConfirmationModal from "@/components/modals/ConfirmationModal"; -import { formatTime } from "@/lib/formatDate"; -import { useXyzen } from "@/store"; -import type { ChatHistoryItem } from "@/store/types"; -import { MapPinIcon } from "@heroicons/react/20/solid"; -import { - ArchiveBoxXMarkIcon, - ChevronRightIcon, - ClockIcon, - MagnifyingGlassIcon, - PlusIcon, - TrashIcon, - UserIcon, - XMarkIcon, -} from "@heroicons/react/24/outline"; -import { motion } from "framer-motion"; -import { useEffect, useMemo, useState } from "react"; - -/** - * XyzenTopics - Topic/Session list component for fullscreen layout - * Displays topics grouped by sessions with management capabilities - */ -export default function XyzenTopics() { - const [isConfirmModalOpen, setConfirmModalOpen] = useState(false); - const [topicToDelete, setTopicToDelete] = useState( - null, - ); - const [searchQuery, setSearchQuery] = useState(""); - const [isClearConfirmOpen, setIsClearConfirmOpen] = useState(false); - - // Select fine-grained pieces to avoid re-renders on message streaming - const chatHistory = useXyzen((s) => s.chatHistory); - const chatHistoryLoading = useXyzen((s) => s.chatHistoryLoading); - const activeChatChannel = useXyzen((s) => s.activeChatChannel); - const user = useXyzen((s) => s.user); - const activateChannel = useXyzen((s) => s.activateChannel); - const togglePinChat = useXyzen((s) => s.togglePinChat); - const fetchChatHistory = useXyzen((s) => s.fetchChatHistory); - const updateTopicName = useXyzen((s) => s.updateTopicName); - const deleteTopic = useXyzen((s) => s.deleteTopic); - const createDefaultChannel = useXyzen((s) => s.createDefaultChannel); - const clearSessionTopics = useXyzen((s) => s.clearSessionTopics); - - // Subscribe only to primitive active sessionId to avoid re-renders on message changes - const activeSessionId = useXyzen((s) => - s.activeChatChannel - ? (s.channels[s.activeChatChannel]?.sessionId ?? 
null) - : null, - ); - - // Load chat history on mount - useEffect(() => { - void fetchChatHistory(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); - - // Check if user is logged in - const isUserLoggedIn = useMemo(() => { - const hasUser = user && (user.id || user.username); - return hasUser; - }, [user]); - - // Get current session topics - const currentSessionTopics = useMemo(() => { - if (!activeChatChannel || !activeSessionId) return []; - return chatHistory.filter((chat) => chat.sessionId === activeSessionId); - }, [activeChatChannel, activeSessionId, chatHistory]); - - // Filter topics by search query - const filteredTopics = useMemo(() => { - if (!searchQuery.trim()) { - return currentSessionTopics; - } - - const query = searchQuery.toLowerCase(); - return currentSessionTopics.filter( - (topic) => - topic.title.toLowerCase().includes(query) || - topic.lastMessage?.toLowerCase().includes(query), - ); - }, [currentSessionTopics, searchQuery]); - - // Sort topics by pinned status and update time - const sortedTopics = useMemo(() => { - return [...filteredTopics].sort((a, b) => { - if (a.isPinned && !b.isPinned) return -1; - if (!a.isPinned && b.isPinned) return 1; - const dateA = new Date(a.updatedAt); - const dateB = new Date(b.updatedAt); - return dateB.getTime() - dateA.getTime(); - }); - }, [filteredTopics]); - - // Activate a topic - const handleActivateTopic = async (topicId: string) => { - await activateChannel(topicId); - }; - - // Toggle pin status - const handleTogglePin = (e: React.MouseEvent, topicId: string) => { - e.stopPropagation(); - togglePinChat(topicId); - }; - - // Handle delete topic - const handleDeleteTopic = (e: React.MouseEvent, topic: ChatHistoryItem) => { - e.stopPropagation(); - - // Prevent deleting the last topic - if (sortedTopics.length <= 1) { - return; - } - - setTopicToDelete(topic); - setConfirmModalOpen(true); - }; - - // Confirm deletion - const confirmDelete = async () => { - if (topicToDelete) { - await deleteTopic(topicToDelete.id); - setTopicToDelete(null); - setConfirmModalOpen(false); - } - }; - - // Handle topic name update - const handleTopicNameUpdate = async (topicId: string, newName: string) => { - await updateTopicName(topicId, newName); - }; - - // Create new topic - const handleCreateNewTopic = async () => { - if (!activeChatChannel) return; - // Read agentId from state snapshot when needed to avoid subscribing to channels - const state = useXyzen.getState(); - const currentAgent = state.channels[activeChatChannel]?.agentId; - await createDefaultChannel(currentAgent); - }; - - // Clear all topics - const handleClearAllTopics = () => { - setIsClearConfirmOpen(true); - }; - - const confirmClearAll = async () => { - if (activeChatChannel && activeSessionId) { - await clearSessionTopics(activeSessionId); - setIsClearConfirmOpen(false); - } - }; - - // Render login prompt - if (!isUserLoggedIn) { - return ( -
- -

- Please Login -

-

- Login to view and manage your topics -

-
- ); - } - - // Render loading state - if (chatHistoryLoading) { - return ( -
- -
- ); - } - - // Render empty state - if (sortedTopics.length === 0 && !searchQuery) { - return ( -
- -

- No Topics Yet -

-

- Start a conversation to create your first topic -

-
- ); - } - - // Render topics list - return ( - <> -
- {/* Header */} -
-
-
-

- Topics -

-

- {currentSessionTopics.length}{" "} - {currentSessionTopics.length === 1 ? "topic" : "topics"} - {searchQuery && ` · ${sortedTopics.length} 个搜索结果`} -

-
- -
-
- - {/* Toolbar: Search and Clear */} -
-
- {/* Search Input */} -
- - setSearchQuery(e.target.value)} - className="w-full rounded-sm border border-neutral-200 bg-white py-2 pl-9 pr-3 text-sm text-neutral-800 placeholder-neutral-400 focus:border-indigo-500 focus:outline-none focus:ring-1 focus:ring-indigo-500 dark:border-neutral-700 dark:bg-neutral-900 dark:text-white dark:placeholder-neutral-500" - /> - {searchQuery && ( - - )} -
- - {/* Clear All Button */} - -
-
- - {/* Topics List */} -
- {sortedTopics.length === 0 ? ( -
- -

- 没有找到匹配的对话 -

-
- ) : ( -
- {sortedTopics.map((topic) => ( - -
handleActivateTopic(topic.id)} - > -
- {topic.isPinned && ( - - )} -
- - handleTopicNameUpdate(topic.id, newTitle) - } - textClassName="text-sm font-medium text-neutral-800 dark:text-white truncate" - className="" - /> -
- {/* {topic.id === activeChatChannel && ( - - Active - - )} */} -
- -
- {formatTime(topic.updatedAt)} - {topic.lastMessage && ( - <> - - {topic.lastMessage} - - )} -
-
- - {/* Actions */} -
- - -
- -
-
-
- ))} -
- )} -
-
- - {/* Delete Confirmation Modal */} - { - setConfirmModalOpen(false); - setTopicToDelete(null); - }} - onConfirm={confirmDelete} - title="Delete Topic" - message={`Are you sure you want to delete "${topicToDelete?.title}"? This action cannot be undone.`} - /> - - {/* Clear All Confirmation Modal */} - setIsClearConfirmOpen(false)} - onConfirm={confirmClearAll} - title="清空所有对话" - message="确定要清空当前会话的所有对话记录吗?此操作不可恢复。" - /> - - ); -} diff --git a/web/src/i18n/locales/en/agents.json b/web/src/i18n/locales/en/agents.json index f97f25fc..91b21e39 100644 --- a/web/src/i18n/locales/en/agents.json +++ b/web/src/i18n/locales/en/agents.json @@ -2,6 +2,9 @@ "add": "Add Agent", "edit": "Edit", "delete": "Delete", + "badges": { + "marketplace": "Published to Marketplace" + }, "editAgent": "Edit Agent", "deleteAgent": "Delete Agent", "addButton": "+ Add Agent", @@ -10,6 +13,8 @@ "deleteTitle": "Delete Agent", "deleteConfirm": "Are you sure you want to permanently delete agent \"{{name}}\"? This action cannot be undone.", "deleteConfirmWithListing": "⚠️ This agent has been published to the marketplace. Deleting it will also remove the published version.\n\nAre you sure you want to permanently delete agent \"{{name}}\"? This action cannot be undone.", + "deleteBlockedTitle": "Can't delete agent", + "deleteBlockedMessage": "This agent is published to Agent Market. Please unpublish it first, then delete it.", "createDescription": "Create a new AI agent with custom prompts and tools.", "updateDescription": "Update the details for your agent.", "systemDescription": "Select a pre-configured system agent with special execution capabilities.", diff --git a/web/src/i18n/locales/ja/agents.json b/web/src/i18n/locales/ja/agents.json index 6a3063ac..7c825e25 100644 --- a/web/src/i18n/locales/ja/agents.json +++ b/web/src/i18n/locales/ja/agents.json @@ -2,6 +2,9 @@ "add": "エージェントを追加", "edit": "編集", "delete": "削除", + "badges": { + "marketplace": "マーケットプレイスに公開済み" + }, "editAgent": "エージェントを編集", "deleteAgent": "エージェントを削除", "addButton": "+ エージェントを追加", @@ -10,6 +13,8 @@ "deleteTitle": "エージェントを削除", "deleteConfirm": "エージェント\"{{name}}\"を完全に削除してもよろしいですか?この操作は元に戻せません。", "deleteConfirmWithListing": "⚠️ このエージェントはマーケットプレイスに公開されています。削除すると、公開版も削除されます。\n\nエージェント\"{{name}}\"を完全に削除してもよろしいですか?この操作は元に戻せません。", + "deleteBlockedTitle": "エージェントを削除できません", + "deleteBlockedMessage": "このエージェントは Agent Market に公開されています。先にマーケットから非公開(下架)にしてから削除してください。", "createDescription": "カスタムプロンプトとツールを使用して新しいAIエージェントを作成します。", "updateDescription": "エージェントの詳細を更新します。", "systemDescription": "特殊な実行機能を持つ事前設定されたシステムエージェントを選択します。", diff --git a/web/src/i18n/locales/zh/agents.json b/web/src/i18n/locales/zh/agents.json index 7c541de0..71fe2ee7 100644 --- a/web/src/i18n/locales/zh/agents.json +++ b/web/src/i18n/locales/zh/agents.json @@ -2,6 +2,9 @@ "add": "添加助手", "edit": "编辑", "delete": "删除", + "badges": { + "marketplace": "已发布到市场" + }, "editAgent": "编辑助手", "deleteAgent": "删除助手", "addButton": "+ 添加助手", @@ -10,6 +13,8 @@ "deleteTitle": "删除助手", "deleteConfirm": "确定要永久删除助手 \"{{name}}\" 吗?此操作无法撤销。", "deleteConfirmWithListing": "⚠️ 此助手已发布到市场。删除后,市场中的发布版本也将被移除。\n\n确定要永久删除助手 \"{{name}}\" 吗?此操作无法撤销。", + "deleteBlockedTitle": "无法删除助手", + "deleteBlockedMessage": "该助手已发布到 Agent Market。请先到 Market 下架后再删除。", "createDescription": "创建一个新的 AI 助手,可以配置专属提示词和工具。", "updateDescription": "更新助手的详细信息。", "systemDescription": "选择一个预配置的系统助手,这些助手具有特殊的执行能力。", diff --git a/web/src/store/slices/uiSlice/index.ts b/web/src/store/slices/uiSlice/index.ts index 5cbaa024..85d9bf0b 100644 --- 
a/web/src/store/slices/uiSlice/index.ts +++ b/web/src/store/slices/uiSlice/index.ts @@ -8,7 +8,11 @@ import { type InputPosition, type LayoutStyle } from "./types"; // Ensure xyzen service is aware of the default backend on startup xyzenService.setBackendUrl(DEFAULT_BACKEND_URL); -export type ActivityPanel = "chat" | "knowledge" | "marketplace"; +export type ActivityPanel = + | "chat" + | "knowledge" + | "marketplace" + | "workspace-test"; export interface UiSlice { backendUrl: string; From fc55288ed8f7f7c31180aac200df9f7384715c5f Mon Sep 17 00:00:00 2001 From: Harvey Date: Wed, 14 Jan 2026 22:23:26 +0800 Subject: [PATCH 03/11] feat: Implement Spatial Workspace with agent management and UI enhancements --- .vscode/settings.example.json | 9 + web/src/app/AppFullscreen.tsx | 2 +- web/src/app/chat/SpatialWorkspace.tsx | 319 +++++++++++++++++++++++ web/src/app/chat/spatial/AgentNode.tsx | 85 ++++++ web/src/app/chat/spatial/FocusedView.tsx | 176 +++++++++++++ web/src/app/chat/spatial/types.ts | 15 ++ 6 files changed, 605 insertions(+), 1 deletion(-) create mode 100644 web/src/app/chat/SpatialWorkspace.tsx create mode 100644 web/src/app/chat/spatial/AgentNode.tsx create mode 100644 web/src/app/chat/spatial/FocusedView.tsx create mode 100644 web/src/app/chat/spatial/types.ts diff --git a/.vscode/settings.example.json b/.vscode/settings.example.json index 54edefea..83f6ab8f 100644 --- a/.vscode/settings.example.json +++ b/.vscode/settings.example.json @@ -64,6 +64,15 @@ "reportImplicitStringConcatenation": "none" }, + "black-formatter.args": ["--line-length", "119"], + + "flake8.args": [ + "--max-line-length", + "119", + "--ignore", + "F401 W503 F541 F841 E226" + ], + "todo-tree.highlights.defaultHighlight": { "icon": "alert", "type": "text", diff --git a/web/src/app/AppFullscreen.tsx b/web/src/app/AppFullscreen.tsx index 1a57f569..1da94aa2 100644 --- a/web/src/app/AppFullscreen.tsx +++ b/web/src/app/AppFullscreen.tsx @@ -5,8 +5,8 @@ import { restrictToVerticalAxis } from "@dnd-kit/modifiers"; import { useEffect, useState } from "react"; import { createPortal } from "react-dom"; +import { SpatialWorkspace } from "@/app/chat/SpatialWorkspace"; import AgentMarketplace from "@/app/marketplace/AgentMarketplace"; -import { SpatialWorkspace } from "@/app/tmp/SpatialWorkspace"; import { ActivityBar } from "@/components/layouts/ActivityBar"; import { AppHeader } from "@/components/layouts/AppHeader"; import KnowledgeBase from "@/components/layouts/KnowledgeBase"; diff --git a/web/src/app/chat/SpatialWorkspace.tsx b/web/src/app/chat/SpatialWorkspace.tsx new file mode 100644 index 00000000..81beed2a --- /dev/null +++ b/web/src/app/chat/SpatialWorkspace.tsx @@ -0,0 +1,319 @@ +import { + Background, + Node, + ReactFlow, + ReactFlowProvider, + useEdgesState, + useNodesState, + useReactFlow, +} from "@xyflow/react"; +import "@xyflow/react/dist/style.css"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; + +import { AnimatePresence } from "framer-motion"; +import { AgentNode } from "./spatial/AgentNode"; +import { FocusedView } from "./spatial/FocusedView"; +import type { AgentData, FlowAgentNodeData } from "./spatial/types"; + +type AgentFlowNode = Node; + +// --- Mock Data --- +const INITIAL_AGENTS: AgentData[] = [ + { + name: "Market Analyst Pro", + role: "Market Analyst", + desc: "Expert in trend forecasting", + avatar: "https://api.dicebear.com/7.x/avataaars/svg?seed=Market", + status: "busy", + size: "large", + }, + { + name: "Creative Writer", + role: "Copywriter", + desc: 
"Marketing copy & storytelling", + avatar: "https://api.dicebear.com/7.x/avataaars/svg?seed=Creative", + status: "idle", + size: "medium", + }, + { + name: "Global Search", + role: "Researcher", + desc: "Real-time info retrieval", + avatar: "https://api.dicebear.com/7.x/avataaars/svg?seed=Search", + status: "idle", + size: "small", + }, + { + name: "Code Auditor", + role: "Security", + desc: "Python/JS security checks", + avatar: "https://api.dicebear.com/7.x/avataaars/svg?seed=Code", + status: "idle", + size: "medium", + }, +]; + +const noopFocus: FlowAgentNodeData["onFocus"] = () => {}; + +const INITIAL_NODES: AgentFlowNode[] = [ + { + id: "1", + type: "agent", + position: { x: 0, y: 0 }, + data: { ...INITIAL_AGENTS[0], onFocus: noopFocus }, + }, + { + id: "2", + type: "agent", + position: { x: 600, y: -200 }, + data: { ...INITIAL_AGENTS[1], onFocus: noopFocus }, + }, + { + id: "3", + type: "agent", + position: { x: -300, y: 400 }, + data: { ...INITIAL_AGENTS[2], onFocus: noopFocus }, + }, + { + id: "4", + type: "agent", + position: { x: 700, y: 500 }, + data: { ...INITIAL_AGENTS[3], onFocus: noopFocus }, + }, +]; + +function InnerWorkspace() { + const [nodes, setNodes, onNodesChange] = + useNodesState(INITIAL_NODES); + const [edges, , onEdgesChange] = useEdgesState([]); + const [focusedAgentId, setFocusedAgentId] = useState(null); + const [prevViewport, setPrevViewport] = useState<{ + x: number; + y: number; + zoom: number; + } | null>(null); + const { setViewport, getViewport, getNode, fitView } = useReactFlow(); + const didInitialFitViewRef = useRef(false); + const cancelInitialFitRef = useRef(false); + const initialFitAttemptsRef = useRef(0); + + useEffect(() => { + if (didInitialFitViewRef.current) return; + if (cancelInitialFitRef.current) return; + + let cancelled = false; + initialFitAttemptsRef.current = 0; + + const tryFit = () => { + if (cancelled) return; + if (didInitialFitViewRef.current) return; + if (cancelInitialFitRef.current) return; + + initialFitAttemptsRef.current += 1; + + const allMeasured = nodes.every((n) => { + const node = getNode(n.id); + const w = node?.measured?.width ?? 0; + const h = node?.measured?.height ?? 0; + return w > 0 && h > 0; + }); + + // If measurement never comes through (rare), still do a best-effort fit. + if (allMeasured || initialFitAttemptsRef.current >= 12) { + didInitialFitViewRef.current = true; + fitView({ padding: 0.22, duration: 0 }); + return; + } + + requestAnimationFrame(tryFit); + }; + + requestAnimationFrame(tryFit); + return () => { + cancelled = true; + }; + }, [fitView, getNode, nodes]); + + const handleFocus = useCallback( + (id: string) => { + // Don't allow the initial fit to run after the user has started interacting. + cancelInitialFitRef.current = true; + didInitialFitViewRef.current = true; + + if (!prevViewport) { + setPrevViewport(getViewport()); + } + setFocusedAgentId(id); + + const node = getNode(id); + if (!node) return; + + const nodeW = node.measured?.width ?? 300; + const nodeH = node.measured?.height ?? 
220; + const centerX = node.position.x + nodeW / 2; + const centerY = node.position.y + nodeH / 2; + + // Target Top-Left (approx 15% x, 25% y) + const targetZoom = 1.35; + const screenX = window.innerWidth * 0.15; + const screenY = window.innerHeight * 0.25; + const x = -centerX * targetZoom + screenX; + const y = -centerY * targetZoom + screenY; + + setViewport({ x, y, zoom: targetZoom }, { duration: 900 }); + }, + [getNode, getViewport, prevViewport, setViewport], + ); + + const handleCloseFocus = useCallback(() => { + setFocusedAgentId(null); + const restore = prevViewport ?? { x: 0, y: 0, zoom: 0.85 }; + setViewport(restore, { duration: 900 }); + setPrevViewport(null); + }, [prevViewport, setViewport]); + + // Inject handleFocus into node data + const nodeTypes = useMemo( + () => ({ + agent: AgentNode, + }), + [], + ); + + // Update nodes with the callback + const nodesWithHandler = useMemo(() => { + return nodes.map((n) => ({ + ...n, + data: { + ...n.data, + onFocus: handleFocus, + }, + })); + }, [nodes, handleFocus]); + + const handleNodeDragStop = useCallback( + (_: unknown, draggedNode: AgentFlowNode) => { + const padding = 24; + + const getSize = (id: string) => { + const node = getNode(id); + const measuredW = node?.measured?.width; + const measuredH = node?.measured?.height; + if (measuredW && measuredH) return { w: measuredW, h: measuredH }; + + const size = node?.data?.size; + if (size === "large") return { w: 400, h: 320 }; + if (size === "medium") return { w: 300, h: 220 }; + return { w: 200, h: 160 }; + }; + + setNodes((prev) => { + const next = prev.map((n) => ({ ...n })); + const moving = next.find((n) => n.id === draggedNode.id); + if (!moving) return prev; + + // Iteratively push the dragged node out of overlaps. + for (let iter = 0; iter < 24; iter += 1) { + let movedThisIter = false; + + const aSize = getSize(moving.id); + const ax1 = moving.position.x; + const ay1 = moving.position.y; + const ax2 = ax1 + aSize.w; + const ay2 = ay1 + aSize.h; + + for (const other of next) { + if (other.id === moving.id) continue; + + const bSize = getSize(other.id); + const bx1 = other.position.x; + const by1 = other.position.y; + const bx2 = bx1 + bSize.w; + const by2 = by1 + bSize.h; + + const overlapX = + Math.min(ax2 + padding, bx2) - Math.max(ax1 - padding, bx1); + const overlapY = + Math.min(ay2 + padding, by2) - Math.max(ay1 - padding, by1); + + if (overlapX > 0 && overlapY > 0) { + // Push along the smallest overlap axis. + if (overlapX < overlapY) { + const aCenterX = (ax1 + ax2) / 2; + const bCenterX = (bx1 + bx2) / 2; + const dir = aCenterX < bCenterX ? -1 : 1; + moving.position = { + ...moving.position, + x: moving.position.x + dir * overlapX, + }; + } else { + const aCenterY = (ay1 + ay2) / 2; + const bCenterY = (by1 + by2) / 2; + const dir = aCenterY < bCenterY ? -1 : 1; + moving.position = { + ...moving.position, + y: moving.position.y + dir * overlapY, + }; + } + + movedThisIter = true; + break; + } + } + + if (!movedThisIter) break; + } + + return next; + }); + }, + [getNode, setNodes], + ); + + const focusedAgent = useMemo(() => { + if (!focusedAgentId) return null; + return nodes.find((n) => n.id === focusedAgentId)?.data; + }, [focusedAgentId, nodes]); + + return ( +
+ + + + + + {focusedAgent && ( + ({ id: n.id, ...n.data }))} + onClose={handleCloseFocus} + onSwitchAgent={(id) => handleFocus(id)} + /> + )} + +
+ ); +} + +export function SpatialWorkspace() { + return ( + + + + ); +} diff --git a/web/src/app/chat/spatial/AgentNode.tsx b/web/src/app/chat/spatial/AgentNode.tsx new file mode 100644 index 00000000..c3806d12 --- /dev/null +++ b/web/src/app/chat/spatial/AgentNode.tsx @@ -0,0 +1,85 @@ +import { cn } from "@/lib/utils"; +import type { Node } from "@xyflow/react"; +import { NodeProps } from "@xyflow/react"; +import { motion } from "framer-motion"; +import type { FlowAgentNodeData } from "./types"; + +type AgentFlowNode = Node; + +export function AgentNode({ id, data, selected }: NodeProps) { + return ( + { + e.stopPropagation(); + data.onFocus(id); + }} + className={cn( + "relative rounded-3xl bg-[#fdfcf8] dark:bg-neutral-900/60 shadow-xl transition-all border border-white/50 dark:border-white/10 backdrop-blur-md group", + selected + ? "ring-2 ring-[#5a6e8c]/20 dark:ring-0 dark:border-indigo-400/50 dark:shadow-[0_0_15px_rgba(99,102,241,0.5),0_0_30px_rgba(168,85,247,0.3)] shadow-2xl" + : "hover:shadow-2xl", + data.size === "large" + ? "w-100 h-80 p-6" + : data.size === "medium" + ? "w-75 h-55 p-6" + : "w-50 h-40 p-4", + )} + > +
+ avatar +
+
+ {data.name} +
+
+ {data.role} +
+
+
+ + {/* Content Placeholder */} +
+ {data.status === "busy" && ( +
+
+ Processing +
+ )} + +
+ + {/* Abstract Data Viz for large cards */} + {data.size === "large" && ( +
+
+ {[40, 70, 55, 90, 60, 80].map((h, i) => ( +
+ ))} +
+
+ )} + + {/* Abstract Text for medium cards */} + {data.size === "medium" && ( +
+
+
+
+
+ )} +
+ + ); +} diff --git a/web/src/app/chat/spatial/FocusedView.tsx b/web/src/app/chat/spatial/FocusedView.tsx new file mode 100644 index 00000000..ae51a488 --- /dev/null +++ b/web/src/app/chat/spatial/FocusedView.tsx @@ -0,0 +1,176 @@ +import { motion } from "framer-motion"; +import { useEffect, useRef } from "react"; +import { AgentData } from "./types"; + +interface FocusedViewProps { + agent: AgentData; + agents: (AgentData & { id: string })[]; + onClose: () => void; + onSwitchAgent: (id: string) => void; +} + +export function FocusedView({ + agent, + agents, + onClose, + onSwitchAgent, +}: FocusedViewProps) { + const switcherRef = useRef(null); + const chatRef = useRef(null); + + useEffect(() => { + const onKeyDown = (e: KeyboardEvent) => { + if (e.key === "Escape") onClose(); + }; + + const onPointerDownCapture = (e: PointerEvent) => { + const target = e.target as HTMLElement | null; + if (!target) return; + + // Clicking on a node should focus it, not close. + if (target.closest(".react-flow__node, .xy-flow__node")) return; + + // Clicking inside UI panels should not close. + if (chatRef.current?.contains(target)) return; + if (switcherRef.current?.contains(target)) return; + + // Prevent XYFlow from starting a pan/drag on the same click, + // which can override the restore viewport animation. + e.preventDefault(); + e.stopPropagation(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (e as any).stopImmediatePropagation?.(); + + onClose(); + }; + + window.addEventListener("keydown", onKeyDown); + window.addEventListener("pointerdown", onPointerDownCapture, true); + return () => { + window.removeEventListener("keydown", onKeyDown); + window.removeEventListener("pointerdown", onPointerDownCapture, true); + }; + }, [onClose]); + + return ( +
+ {/* 1. Left Column: Top (Empty for Node visibility) + Bottom (Switcher) */} +
+ {/* Agent Switcher List */} + +
+

+ Active Agents +

+
+
+ {agents.map((a) => ( + + ))} +
+
+
+ + {/* 2. Main Chat Area */} + + {/* Chat Header */} +
+
+
+ + Session Active + +
+
{/* Tools icons or standard window controls */}
+
+ + {/* Chat Body Mock */} +
+
+
+
+

+ Hello! I'm {agent.name}. I'm ready to assist you with your tasks + today. I can access your latest files and context. +

+
+
+
+ + {/* Input Area */} +
+
+ + +
+
+ +
+ ); +} diff --git a/web/src/app/chat/spatial/types.ts b/web/src/app/chat/spatial/types.ts new file mode 100644 index 00000000..260ed208 --- /dev/null +++ b/web/src/app/chat/spatial/types.ts @@ -0,0 +1,15 @@ +export interface AgentData { + name: string; + role: string; + desc: string; + avatar: string; + status: "idle" | "busy" | "offline"; + size: "large" | "medium" | "small"; +} + +export interface AgentNodeData extends AgentData { + onFocus: (id: string) => void; +} + +// XYFlow requires node.data to be a Record +export type FlowAgentNodeData = AgentNodeData & Record; From aca0c4fec076ec0e699e890f76ff7674ad6cf187 Mon Sep 17 00:00:00 2001 From: Harvey Date: Wed, 14 Jan 2026 23:01:10 +0800 Subject: [PATCH 04/11] feat: Refactor modal components and enhance agent node resizing functionality --- web/src/app/chat/SpatialWorkspace.tsx | 46 ++- web/src/app/chat/spatial/AgentNode.tsx | 264 +++++++++++++----- web/src/app/chat/spatial/types.ts | 3 +- web/src/app/explore/McpExploreContent.tsx | 2 +- .../headless => components/animate}/modal.tsx | 0 .../components/features/ForkAgentModal.tsx | 2 +- .../components/features/PointsInfoModal.tsx | 2 +- .../components/features/PublishAgentModal.tsx | 2 +- .../components/features/TokenInputModal.tsx | 2 +- .../knowledge/CreateKnowledgeSetModal.tsx | 2 +- web/src/components/layouts/McpListModal.tsx | 2 +- web/src/components/modals/AddAgentModal.tsx | 4 +- .../components/modals/AddLlmProviderModal.tsx | 2 +- .../components/modals/AddMcpServerModal.tsx | 2 +- web/src/components/modals/EditAgentModal.tsx | 12 +- .../components/modals/EditMcpServerModal.tsx | 2 +- web/src/components/modals/SettingsModal.tsx | 2 +- web/src/components/modals/ToolTestModal.tsx | 2 +- .../modals/settings/McpSettings.tsx | 2 +- 19 files changed, 257 insertions(+), 98 deletions(-) rename web/src/components/animate-ui/{primitives/headless => components/animate}/modal.tsx (100%) diff --git a/web/src/app/chat/SpatialWorkspace.tsx b/web/src/app/chat/SpatialWorkspace.tsx index 81beed2a..caf16113 100644 --- a/web/src/app/chat/SpatialWorkspace.tsx +++ b/web/src/app/chat/SpatialWorkspace.tsx @@ -87,6 +87,7 @@ function InnerWorkspace() { useNodesState(INITIAL_NODES); const [edges, , onEdgesChange] = useEdgesState([]); const [focusedAgentId, setFocusedAgentId] = useState(null); + const containerRef = useRef(null); const [prevViewport, setPrevViewport] = useState<{ x: number; y: number; @@ -148,16 +149,28 @@ function InnerWorkspace() { const node = getNode(id); if (!node) return; - const nodeW = node.measured?.width ?? 300; - const nodeH = node.measured?.height ?? 220; - const centerX = node.position.x + nodeW / 2; + const measuredH = node.measured?.height; + const gridSize = (node.data as FlowAgentNodeData | undefined)?.gridSize; + const fallbackH = + gridSize?.w && gridSize?.h + ? gridSize.h * 160 + (gridSize.h - 1) * 16 + : 220; + + const nodeH = measuredH ?? fallbackH; const centerY = node.position.y + nodeH / 2; - // Target Top-Left (approx 15% x, 25% y) + // Focus layout: keep a consistent left padding regardless of node size. const targetZoom = 1.35; - const screenX = window.innerWidth * 0.15; - const screenY = window.innerHeight * 0.25; - const x = -centerX * targetZoom + screenX; + const rect = containerRef.current?.getBoundingClientRect(); + const containerW = rect?.width ?? window.innerWidth; + const containerH = rect?.height ?? 
window.innerHeight; + + const leftPadding = Math.max(24, Math.min(64, containerW * 0.08)); + const screenX = leftPadding; + const screenY = containerH * 0.25; + + // Align the node's left edge to screenX. + const x = -node.position.x * targetZoom + screenX; const y = -centerY * targetZoom + screenY; setViewport({ x, y, zoom: targetZoom }, { duration: 900 }); @@ -187,9 +200,10 @@ function InnerWorkspace() { data: { ...n.data, onFocus: handleFocus, + isFocused: n.id === focusedAgentId, }, })); - }, [nodes, handleFocus]); + }, [nodes, handleFocus, focusedAgentId]); const handleNodeDragStop = useCallback( (_: unknown, draggedNode: AgentFlowNode) => { @@ -201,7 +215,16 @@ function InnerWorkspace() { const measuredH = node?.measured?.height; if (measuredW && measuredH) return { w: measuredW, h: measuredH }; - const size = node?.data?.size; + const d = node?.data as FlowAgentNodeData | undefined; + if (d?.gridSize) { + const { w, h } = d.gridSize; + return { + w: w * 200 + (w - 1) * 16, + h: h * 160 + (h - 1) * 16, + }; + } + + const size = d?.size; if (size === "large") return { w: 400, h: 320 }; if (size === "medium") return { w: 300, h: 220 }; return { w: 200, h: 160 }; @@ -276,7 +299,10 @@ function InnerWorkspace() { }, [focusedAgentId, nodes]); return ( -
+
; -export function AgentNode({ id, data, selected }: NodeProps) { +// Helper to calc size +// Base unit: 1x1 = 200x160. Gap = 16. +const BASE_W = 200; +const BASE_H = 160; +const GAP = 16; + +const getSizeStyle = (w?: number, h?: number, sizeStr?: string) => { + if (w && h) { + return { + width: w * BASE_W + (w - 1) * GAP, + height: h * BASE_H + (h - 1) * GAP, + }; + } + // Fallback map + if (sizeStr === "large") return { width: 400, height: 320 }; // ~2x2 + if (sizeStr === "medium") return { width: 300, height: 220 }; // ~1.5? old values + if (sizeStr === "small") return { width: 200, height: 160 }; // 1x1 + return { width: 200, height: 160 }; +}; + +function GridResizer({ + currentW = 1, + currentH = 1, + onResize, +}: { + currentW?: number; + currentH?: number; + onResize: (w: number, h: number) => void; +}) { + const [hover, setHover] = useState<{ w: number; h: number } | null>(null); + return ( - { - e.stopPropagation(); - data.onFocus(id); - }} - className={cn( - "relative rounded-3xl bg-[#fdfcf8] dark:bg-neutral-900/60 shadow-xl transition-all border border-white/50 dark:border-white/10 backdrop-blur-md group", - selected - ? "ring-2 ring-[#5a6e8c]/20 dark:ring-0 dark:border-indigo-400/50 dark:shadow-[0_0_15px_rgba(99,102,241,0.5),0_0_30px_rgba(168,85,247,0.3)] shadow-2xl" - : "hover:shadow-2xl", - data.size === "large" - ? "w-100 h-80 p-6" - : data.size === "medium" - ? "w-75 h-55 p-6" - : "w-50 h-40 p-4", - )} - > -
- avatar -
-
- {data.name} -
-
- {data.role} -
+
+
+ Adjust the grid size of this agent widget. +
+
+
setHover(null)} + > + {Array.from({ length: 9 }).map((_, i) => { + const x = (i % 3) + 1; + const y = Math.floor(i / 3) + 1; + const isHovered = hover && x <= hover.w && y <= hover.h; + const isSelected = !hover && x <= currentW && y <= currentH; + + return ( +
setHover({ w: x, h: y })} + onClick={(e) => { + e.preventDefault(); + e.stopPropagation(); + onResize(x, y); + }} + /> + ); + })}
+
+ {hover ? `${hover.w} x ${hover.h}` : `${currentW} x ${currentH}`} +
+
+ ); +} - {/* Content Placeholder */} -
- {data.status === "busy" && ( -
-
- Processing -
+export function AgentNode({ id, data, selected }: NodeProps) { + const { updateNodeData } = useReactFlow(); + const [isSettingsOpen, setIsSettingsOpen] = useState(false); + // Determine current dim + const currentW = data.gridSize?.w || (data.size === "large" ? 2 : 1); + const currentH = data.gridSize?.h || (data.size === "large" ? 2 : 1); + + const style = getSizeStyle(data.gridSize?.w, data.gridSize?.h, data.size); + + return ( + <> + setIsSettingsOpen(false)} + title="Widget Settings" + maxWidth="max-w-xs" + > + { + updateNodeData(id, { + gridSize: { w, h }, + size: w * h > 3 ? "large" : w * h > 1 ? "medium" : "small", + }); + // Optional: Close modal after selection if desired, or keep open + // setIsSettingsOpen(false); + }} + /> + + + { + // Only trigger focus if we are NOT clicking inside the settings menu interactions + e.stopPropagation(); + data.onFocus(id); + }} + className={cn( + "relative group rounded-3xl", // Removed bg/border from here + data.isFocused ? "z-50" : "z-0", // focused node higher z-index )} + > + {/* IsFocused Glow - BEHIND CARD */} + {data.isFocused && ( + + )} + + {/* Card Background Layer - Acts as the solid surface */} +
-
- - {/* Abstract Data Viz for large cards */} - {data.size === "large" && ( -
-
- {[40, 70, 55, 90, 60, 80].map((h, i) => ( -
- ))} + {/* Content Container - On Top */} +
+
+ +
+ +
+ avatar +
+
+ {data.name} +
+
+ {data.role} +
- )} - {/* Abstract Text for medium cards */} - {data.size === "medium" && ( -
-
-
-
+
+ {data.status === "busy" && ( +
+
+ Processing +
+ )} + +
+ + {/* Dynamic Abstract Viz based on size */} + {((data.gridSize && data.gridSize.w * data.gridSize.h >= 2) || + data.size === "large" || + data.size === "medium") && ( +
+
+ {[40, 70, 55, 90, 60, 80] + .slice(0, currentW * currentH + 2) + .map((h, i) => ( +
+ ))} +
+
+ )}
- )} -
- +
+ + ); } diff --git a/web/src/app/chat/spatial/types.ts b/web/src/app/chat/spatial/types.ts index 260ed208..face384c 100644 --- a/web/src/app/chat/spatial/types.ts +++ b/web/src/app/chat/spatial/types.ts @@ -5,10 +5,11 @@ export interface AgentData { avatar: string; status: "idle" | "busy" | "offline"; size: "large" | "medium" | "small"; + gridSize?: { w: number; h: number }; // 1-3 grid system } - export interface AgentNodeData extends AgentData { onFocus: (id: string) => void; + isFocused?: boolean; } // XYFlow requires node.data to be a Record diff --git a/web/src/app/explore/McpExploreContent.tsx b/web/src/app/explore/McpExploreContent.tsx index 8dcd918b..6b9c481c 100644 --- a/web/src/app/explore/McpExploreContent.tsx +++ b/web/src/app/explore/McpExploreContent.tsx @@ -1,5 +1,5 @@ "use client"; -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import McpServerDetail from "@/marketplace/components/McpServerDetail"; import SmitheryServerDetail from "@/marketplace/components/SmitheryServerDetail"; import UnifiedMcpMarketList from "@/marketplace/components/UnifiedMcpMarketList"; diff --git a/web/src/components/animate-ui/primitives/headless/modal.tsx b/web/src/components/animate-ui/components/animate/modal.tsx similarity index 100% rename from web/src/components/animate-ui/primitives/headless/modal.tsx rename to web/src/components/animate-ui/components/animate/modal.tsx diff --git a/web/src/components/features/ForkAgentModal.tsx b/web/src/components/features/ForkAgentModal.tsx index 614c60db..6c781dd9 100644 --- a/web/src/components/features/ForkAgentModal.tsx +++ b/web/src/components/features/ForkAgentModal.tsx @@ -1,6 +1,6 @@ "use client"; -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { Input } from "@/components/base/Input"; import { useForkAgent } from "@/hooks/useMarketplace"; import { Button, Field, Label } from "@headlessui/react"; diff --git a/web/src/components/features/PointsInfoModal.tsx b/web/src/components/features/PointsInfoModal.tsx index d16b141b..2cca4224 100644 --- a/web/src/components/features/PointsInfoModal.tsx +++ b/web/src/components/features/PointsInfoModal.tsx @@ -1,4 +1,4 @@ -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { BoltIcon, CurrencyYenIcon, diff --git a/web/src/components/features/PublishAgentModal.tsx b/web/src/components/features/PublishAgentModal.tsx index 46bb01c7..cee1dcee 100644 --- a/web/src/components/features/PublishAgentModal.tsx +++ b/web/src/components/features/PublishAgentModal.tsx @@ -1,6 +1,6 @@ "use client"; -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { usePublishAgent } from "@/hooks/useMarketplace"; import { Button, Field, Label, Switch } from "@headlessui/react"; import { diff --git a/web/src/components/features/TokenInputModal.tsx b/web/src/components/features/TokenInputModal.tsx index 032915e4..7fd2e041 100644 --- a/web/src/components/features/TokenInputModal.tsx +++ b/web/src/components/features/TokenInputModal.tsx @@ -1,4 +1,4 @@ -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { useState } from 
"react"; export interface TokenInputModalProps { diff --git a/web/src/components/knowledge/CreateKnowledgeSetModal.tsx b/web/src/components/knowledge/CreateKnowledgeSetModal.tsx index 19ca9845..e78c8e6f 100644 --- a/web/src/components/knowledge/CreateKnowledgeSetModal.tsx +++ b/web/src/components/knowledge/CreateKnowledgeSetModal.tsx @@ -1,4 +1,4 @@ -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { useEffect, useMemo, useState } from "react"; import { useTranslation } from "react-i18next"; diff --git a/web/src/components/layouts/McpListModal.tsx b/web/src/components/layouts/McpListModal.tsx index 793c5e05..215ad7a2 100644 --- a/web/src/components/layouts/McpListModal.tsx +++ b/web/src/components/layouts/McpListModal.tsx @@ -1,5 +1,5 @@ +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { LiquidButton } from "@/components/animate-ui/primitives/buttons/liquid"; -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; import { LoadingSpinner } from "@/components/base/LoadingSpinner"; import { AddMcpServerModal } from "@/components/modals/AddMcpServerModal"; import { EditMcpServerModal } from "@/components/modals/EditMcpServerModal"; diff --git a/web/src/components/modals/AddAgentModal.tsx b/web/src/components/modals/AddAgentModal.tsx index 25ccb3ea..5010c2f9 100644 --- a/web/src/components/modals/AddAgentModal.tsx +++ b/web/src/components/modals/AddAgentModal.tsx @@ -1,4 +1,4 @@ -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { Input } from "@/components/base/Input"; import { useXyzen } from "@/store"; import type { Agent, SystemAgentTemplate } from "@/types/agents"; @@ -13,8 +13,8 @@ import { TabPanels, } from "@headlessui/react"; import { - PlusIcon, BeakerIcon, + PlusIcon, SparklesIcon, } from "@heroicons/react/24/outline"; import React, { useEffect, useState } from "react"; diff --git a/web/src/components/modals/AddLlmProviderModal.tsx b/web/src/components/modals/AddLlmProviderModal.tsx index ade2eca1..111d0887 100644 --- a/web/src/components/modals/AddLlmProviderModal.tsx +++ b/web/src/components/modals/AddLlmProviderModal.tsx @@ -1,4 +1,4 @@ -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { Input } from "@/components/base/Input"; import { useXyzen } from "@/store"; import type { LlmProviderCreate } from "@/types/llmProvider"; diff --git a/web/src/components/modals/AddMcpServerModal.tsx b/web/src/components/modals/AddMcpServerModal.tsx index 38d7890d..a8ebfef4 100644 --- a/web/src/components/modals/AddMcpServerModal.tsx +++ b/web/src/components/modals/AddMcpServerModal.tsx @@ -1,9 +1,9 @@ +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { FlipButton, FlipButtonBack, FlipButtonFront, } from "@/components/animate-ui/components/buttons/flip"; -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; import { Input } from "@/components/base/Input"; import { useXyzen } from "@/store"; import type { McpServerCreate } from "@/types/mcp"; diff --git a/web/src/components/modals/EditAgentModal.tsx b/web/src/components/modals/EditAgentModal.tsx index e892bf11..7c99bcbe 100644 --- a/web/src/components/modals/EditAgentModal.tsx +++ b/web/src/components/modals/EditAgentModal.tsx 
@@ -1,15 +1,15 @@ -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { Input } from "@/components/base/Input"; -import PublishAgentModal from "@/components/features/PublishAgentModal"; import { AgentGraphEditor } from "@/components/editors/AgentGraphEditor"; import { JsonEditor } from "@/components/editors/JsonEditor"; +import PublishAgentModal from "@/components/features/PublishAgentModal"; import { useXyzen } from "@/store"; import type { Agent } from "@/types/agents"; import type { GraphConfig } from "@/types/graphConfig"; import { extractSimpleConfig, - mergeSimpleConfigToGraphConfig, isStandardReactPattern, + mergeSimpleConfigToGraphConfig, type SimpleAgentConfig, } from "@/utils/agentConfigMapper"; import { @@ -23,12 +23,12 @@ import { TabPanels, } from "@headlessui/react"; import { + CodeBracketIcon, + CubeTransparentIcon, PlusIcon, SparklesIcon, - CubeTransparentIcon, - CodeBracketIcon, } from "@heroicons/react/24/outline"; -import React, { useEffect, useState, useCallback, useMemo } from "react"; +import React, { useCallback, useEffect, useMemo, useState } from "react"; import { useTranslation } from "react-i18next"; import { McpServerItem } from "./McpServerItem"; diff --git a/web/src/components/modals/EditMcpServerModal.tsx b/web/src/components/modals/EditMcpServerModal.tsx index 830795ce..bb2e0e09 100644 --- a/web/src/components/modals/EditMcpServerModal.tsx +++ b/web/src/components/modals/EditMcpServerModal.tsx @@ -1,4 +1,4 @@ -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { Input } from "@/components/base/Input"; import { useXyzen } from "@/store"; import type { McpServer, McpServerUpdate } from "@/types/mcp"; diff --git a/web/src/components/modals/SettingsModal.tsx b/web/src/components/modals/SettingsModal.tsx index d38794ba..e6171092 100644 --- a/web/src/components/modals/SettingsModal.tsx +++ b/web/src/components/modals/SettingsModal.tsx @@ -1,4 +1,4 @@ -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { useXyzen } from "@/store"; import { AdjustmentsHorizontalIcon, diff --git a/web/src/components/modals/ToolTestModal.tsx b/web/src/components/modals/ToolTestModal.tsx index d662c868..bf0ab396 100644 --- a/web/src/components/modals/ToolTestModal.tsx +++ b/web/src/components/modals/ToolTestModal.tsx @@ -1,4 +1,4 @@ -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { LoadingSpinner } from "@/components/base/LoadingSpinner"; import { useXyzen } from "@/store"; import type { McpServer } from "@/types/mcp"; diff --git a/web/src/components/modals/settings/McpSettings.tsx b/web/src/components/modals/settings/McpSettings.tsx index e52eea72..92a2da76 100644 --- a/web/src/components/modals/settings/McpSettings.tsx +++ b/web/src/components/modals/settings/McpSettings.tsx @@ -1,6 +1,6 @@ +import { Modal } from "@/components/animate-ui/components/animate/modal"; import { Button } from "@/components/animate-ui/primitives/buttons/button"; import { LiquidButton } from "@/components/animate-ui/primitives/buttons/liquid"; -import { Modal } from "@/components/animate-ui/primitives/headless/modal"; import { LoadingSpinner } from "@/components/base/LoadingSpinner"; 
import { AddMcpServerModal } from "@/components/modals/AddMcpServerModal"; import { EditMcpServerModal } from "@/components/modals/EditMcpServerModal"; From bee3b888f5f713c724a1ef26690e81c64726dd93 Mon Sep 17 00:00:00 2001 From: "xinquiry(SII)" <100398322+xinquiry@users.noreply.github.com> Date: Wed, 14 Jan 2026 23:06:46 +0800 Subject: [PATCH 05/11] feat: add tier-based consumption pricing with strategy pattern (#165) * feat: remove user define provider * feat: add model_tier * feat: model tier mvp * fix: fix frontend graph edit * feat: add intelligent model selection and Redis stateless infrastructure - Add LLM-based model selector that chooses optimal model for user's task - Add Redis infrastructure for multi-pod deployments (cache, pub/sub) - Add Redis-backed token cache with local fallback - Add RedisBroadcastManager for cross-pod WebSocket broadcasts - Fix model-to-provider resolution to ensure correct provider is used - Use shared Redis client in topic_generator to prevent connection leaks - Add error handling for MCP WebSocket Redis subscriber Co-Authored-By: Claude * feat: add missing translation * fix: modify translation * feat: add tier-based consumption pricing with strategy pattern - Implement TierBasedConsumptionStrategy with configurable rates: - ULTRA: 6.8x, PRO: 3.0x, STANDARD: 1.0x, LITE: free - Add consumption tracking fields to ConsumeRecord model - Create database migration for tier pricing columns - Fix new agent model resolution fallback to STANDARD tier - Support ProviderType enum in create_langchain_model - Improve model selector prompt and add debug logging - Add fallback markers to STANDARD and LITE tier candidates - Update token rates (input: 0.2/1000, output: 1/1000) - Add unit tests for consumption strategies Co-Authored-By: Claude --------- Co-authored-by: Claude --- AGENTS.md | 4 + service/app/agents/graph_builder.py | 14 +- service/app/api/v1/providers.py | 46 +- service/app/api/ws/v1/chat.py | 12 +- service/app/api/ws/v1/mcp.py | 17 +- service/app/configs/redis.py | 9 + service/app/core/chat/langchain.py | 107 +++- service/app/core/chat/model_selector.py | 244 ++++++++ service/app/core/chat/topic_generator.py | 74 +-- service/app/core/consume.py | 18 + service/app/core/consume_calculator.py | 59 ++ service/app/core/consume_strategy.py | 118 ++++ service/app/core/model_registry/service.py | 116 +++- service/app/core/providers/manager.py | 32 +- service/app/core/websocket.py | 105 +++- service/app/infra/redis/__init__.py | 92 +++ service/app/middleware/auth/cache.py | 179 +++++- service/app/middleware/auth/simple_cache.py | 175 +++++- service/app/models/consume.py | 8 + service/app/models/sessions.py | 5 + service/app/repos/provider.py | 14 + service/app/repos/session.py | 24 +- service/app/schemas/model_tier.py | 218 +++++++ service/app/tasks/chat.py | 31 +- .../ab49b572f009_add_tier_to_session.py | 39 ++ ...712246c7034_add_tier_pricing_to_consume.py | 40 ++ .../test_handlers/test_provider_api.py | 127 ++-- .../chat/test_topic_generator_selector.py | 35 -- .../unit/test_core/test_consume_strategy.py | 239 ++++++++ .../unit/test_schemas/test_model_tier.py | 99 ++++ .../components/features/SettingsButton.tsx | 2 +- .../layouts/components/ChatToolbar.tsx | 70 +-- .../layouts/components/TierSelector.tsx | 137 +++++ .../components/modals/AddLlmProviderModal.tsx | 253 -------- web/src/components/modals/SettingsModal.tsx | 51 -- .../modals/settings/ProviderConfigForm.tsx | 545 ------------------ .../modals/settings/ProviderList.tsx | 213 ------- 
web/src/components/modals/settings/index.ts | 2 - web/src/core/session/types.ts | 2 + web/src/hooks/queries/index.ts | 1 + web/src/hooks/queries/queryKeys.ts | 1 + web/src/hooks/queries/useProvidersQuery.ts | 16 + web/src/i18n/locales/en/app.json | 21 + web/src/i18n/locales/ja/app.json | 21 + web/src/i18n/locales/zh/app.json | 21 + web/src/service/llmProviderService.ts | 16 + web/src/service/sessionService.ts | 3 + web/src/store/slices/chatSlice.ts | 3 + web/src/store/types.ts | 2 + 49 files changed, 2326 insertions(+), 1354 deletions(-) create mode 100644 service/app/core/chat/model_selector.py create mode 100644 service/app/core/consume_calculator.py create mode 100644 service/app/core/consume_strategy.py create mode 100644 service/app/infra/redis/__init__.py create mode 100644 service/app/schemas/model_tier.py create mode 100644 service/migrations/versions/ab49b572f009_add_tier_to_session.py create mode 100644 service/migrations/versions/c712246c7034_add_tier_pricing_to_consume.py delete mode 100644 service/tests/unit/core/chat/test_topic_generator_selector.py create mode 100644 service/tests/unit/test_core/test_consume_strategy.py create mode 100644 service/tests/unit/test_schemas/test_model_tier.py create mode 100644 web/src/components/layouts/components/TierSelector.tsx delete mode 100644 web/src/components/modals/AddLlmProviderModal.tsx delete mode 100644 web/src/components/modals/settings/ProviderConfigForm.tsx delete mode 100644 web/src/components/modals/settings/ProviderList.tsx diff --git a/AGENTS.md b/AGENTS.md index 01e138b4..dc890b83 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -72,6 +72,10 @@ types/ ## Core Patterns +**Stateless Async Execution**: Decouple connection management (FastAPI) from heavy computation (Celery). +* State Offloading: API containers remain stateless. Ephemeral state (Queues, Pub/Sub channels) resides in Redis; persistent state in DB. +* Pub/Sub Bridge: Workers process tasks independently and broadcast results back to the specific API pod via Redis channels (chat:{connection_id}), enabling independent scaling of Web and Worker layers. + **No-Foreign-Key Database**: Use logical references (`user_id: str`) instead of FK constraints. Handle relationships in service layer. **Repository Pattern**: Data access via `repos/` classes. Business logic in `core/` services. diff --git a/service/app/agents/graph_builder.py b/service/app/agents/graph_builder.py index 2940d5cb..6e22e34e 100644 --- a/service/app/agents/graph_builder.py +++ b/service/app/agents/graph_builder.py @@ -377,7 +377,19 @@ async def llm_node(state: StateDict) -> StateDict: messages = messages + [HumanMessage(content=prompt)] # Invoke LLM - response = await llm.ainvoke(messages) + try: + response = await llm.ainvoke(messages) + except ValueError as e: + if "expected value at line 1 column 1" in str(e): + logger.error( + f"[LLM Node: {config.id}] JSON parsing failed. This often happens due to " + "Azure Content Filters or empty model response." + ) + raise ValueError( + "Model execution failed: The provider returned an invalid response. " + "This may be due to safety filters blocking the content." 
+ ) from e + raise # Handle structured output if structured_model and isinstance(response, BaseModel): diff --git a/service/app/api/v1/providers.py b/service/app/api/v1/providers.py index 5d9cbd97..c190b49c 100644 --- a/service/app/api/v1/providers.py +++ b/service/app/api/v1/providers.py @@ -126,7 +126,8 @@ async def get_available_models_for_user( Dict[str, List[ModelInfo]]: Dictionary mapping provider ID (as string) to list of ModelInfo """ provider_repo = ProviderRepository(db) - providers = await provider_repo.get_providers_by_user(user_id, include_system=True) + # Only return models for system providers (user-defined providers are disabled) + providers = await provider_repo.get_all_system_providers() result: dict[str, list[ModelInfo]] = {} @@ -173,9 +174,9 @@ async def get_system_providers( db: AsyncSession = Depends(get_session), ) -> list[ProviderRead]: """ - Get all providers accessible to the current authenticated user. + Get all system providers. - Includes both user's own providers and system providers. System provider + Returns all system providers configured via environment variables. API keys and endpoints are masked for security reasons. Args: @@ -183,17 +184,17 @@ async def get_system_providers( db: Database session (injected by dependency) Returns: - List[ProviderRead]: List of providers accessible to the user + List[ProviderRead]: List of system providers Raises: - HTTPException: None - this endpoint always succeeds, returning empty list if no providers + HTTPException: 404 if no system providers found """ provider_repo = ProviderRepository(db) - provider = await provider_repo.get_system_provider() - if not provider: - raise HTTPException(status_code=404, detail="System provider not found") + providers = await provider_repo.get_all_system_providers() + if not providers: + raise HTTPException(status_code=404, detail="No system providers found") - return [_sanitize_provider_read(provider)] + return [_sanitize_provider_read(p) for p in providers] @router.get("/me", response_model=list[ProviderRead]) @@ -202,20 +203,24 @@ async def get_my_providers( db: AsyncSession = Depends(get_session), ) -> list[ProviderRead]: """ - Get all user-scoped providers owned by the current authenticated user. + Get all providers accessible to the current authenticated user. + + Note: User-defined providers are disabled. This endpoint now returns + only system providers. Args: - user: Authenticated user ID (injected by dependency) + user_id: Authenticated user ID (injected by dependency) db: Database session (injected by dependency) Returns: - List[ProviderRead]: List of providers accessible to the user + List[ProviderRead]: List of system providers accessible to the user Raises: HTTPException: None - this endpoint always succeeds, returning empty list if no providers """ provider_repo = ProviderRepository(db) - providers = await provider_repo.get_providers_by_user(user_id, include_system=True) + # Only return system providers (user-defined providers are disabled) + providers = await provider_repo.get_all_system_providers() return [_sanitize_provider_read(p) for p in providers] @@ -228,20 +233,27 @@ async def create_provider( """ Create a new provider for the current authenticated user. - The user_id is automatically set from the authenticated user context. - If this is the user's first provider, it will automatically be set as default. + NOTE: User-defined provider creation is disabled. Only system providers + configured via environment variables are supported. 
Args: provider_data: Provider creation data - user: Authenticated user ID (injected by dependency) + user_id: Authenticated user ID (injected by dependency) db: Database session (injected by dependency) Returns: ProviderRead: The newly created provider Raises: - HTTPException: 400 if invalid provider_type, 500 if creation fails + HTTPException: 403 - User provider creation is disabled """ + # User-defined providers are disabled - only system providers are allowed + raise HTTPException( + status_code=403, + detail="User-defined provider creation is disabled. Only system providers are available.", + ) + + # Original code below is unreachable but kept for reference # Validate provider_type try: ProviderType(provider_data.provider_type) diff --git a/service/app/api/ws/v1/chat.py b/service/app/api/ws/v1/chat.py index db719357..c44b715a 100644 --- a/service/app/api/ws/v1/chat.py +++ b/service/app/api/ws/v1/chat.py @@ -150,12 +150,8 @@ async def chat_websocket( ) await db.flush() - # 3. Echo user message (via Redis for consistency? No, echo immediately usually better UI) - # But to keep consistent order, maybe better to just let frontend handle optimistic update? - # Original code sent it back. We will send it back via WS directly or relying on frontend? - # Original: sent back "user_message_with_files". + # 3. Echo user message user_message_with_files = await message_repo.get_message_with_files(user_message.id) - # We can send this directly via WS since we are in the loop if user_message_with_files: await websocket.send_text(user_message_with_files.model_dump_json()) else: @@ -221,7 +217,7 @@ async def chat_websocket( access_token=auth_ctx.access_token if auth_ctx.auth_provider.lower() == "bohr_app" else None, ) - # 7. Topic Renaming (Keep local async task or move to Celery too? Local is fine for now) + # 7. Topic Renaming - uses Redis pub/sub for cross-pod delivery topic_refreshed = await topic_repo.get_topic_with_details(topic_id) if topic_refreshed and topic_refreshed.name in ["新的聊天", "New Chat", "New Topic"]: msgs = await message_repo.get_messages_by_topic(topic_id, limit=5) @@ -232,13 +228,9 @@ async def chat_websocket( topic_id, session_id, auth_ctx.user_id, - manager, # Passing manager might be tricky if it uses non-redis send connection_id, ) ) - # NOTE: generate_and_update_topic_title uses manager.send_personal_message - # Our local 'manager' here sends via WS. So it works as long as THIS websocket is open. - # If user disconnects, title update might fail to notify FE, but DB update happens. except WebSocketDisconnect: logger.info(f"WebSocket disconnected: {connection_id}") diff --git a/service/app/api/ws/v1/mcp.py b/service/app/api/ws/v1/mcp.py index 6cb3e4b1..a8538eae 100644 --- a/service/app/api/ws/v1/mcp.py +++ b/service/app/api/ws/v1/mcp.py @@ -1,13 +1,28 @@ +import logging + from fastapi import APIRouter, WebSocket, WebSocketDisconnect from app.core.websocket import mcp_websocket_manager +logger = logging.getLogger(__name__) + router = APIRouter(tags=["MCP Updates"]) @router.websocket("") async def websocket_endpoint(websocket: WebSocket) -> None: - """WebSocket endpoint for MCP server status updates.""" + """WebSocket endpoint for MCP server status updates. + + This endpoint subscribes to Redis pub/sub for cross-pod broadcasts, + ensuring clients receive updates regardless of which pod they connect to. 
+ """ + # Start Redis subscriber if not already running (non-blocking, logs errors internally) + try: + await mcp_websocket_manager.start_subscriber() + except Exception as e: + # Log but don't fail - local broadcasts will still work + logger.warning(f"Failed to start Redis subscriber for MCP updates: {e}") + await mcp_websocket_manager.connect(websocket) try: while True: diff --git a/service/app/configs/redis.py b/service/app/configs/redis.py index cdcb900c..89ca4be7 100644 --- a/service/app/configs/redis.py +++ b/service/app/configs/redis.py @@ -1,3 +1,6 @@ +from typing import Literal + +from pydantic import Field from pydantic_settings import BaseSettings @@ -7,6 +10,12 @@ class RedisConfig(BaseSettings): DB: int = 0 PASSWORD: str | None = None + # Cache backend: "local" for in-memory (single pod), "redis" for distributed + CacheBackend: Literal["local", "redis"] = Field( + default="redis", + description="Cache backend for token auth and other caches. Use 'redis' for multi-pod deployments.", + ) + @property def REDIS_URL(self) -> str: if self.PASSWORD: diff --git a/service/app/core/chat/langchain.py b/service/app/core/chat/langchain.py index d147ebf8..1b87d575 100644 --- a/service/app/core/chat/langchain.py +++ b/service/app/core/chat/langchain.py @@ -84,7 +84,13 @@ async def get_ai_response_stream_langchain_legacy( agent = await _resolve_agent(db, agent, topic) # Determine provider and model - provider_id, model_name = await _resolve_provider_and_model(db, agent, topic) + provider_id, model_name = await _resolve_provider_and_model( + db=db, + agent=agent, + topic=topic, + message_text=message_text, + user_provider_manager=user_provider_manager, + ) # Build system prompt system_prompt = await build_system_prompt(db, agent, model_name) @@ -143,14 +149,25 @@ async def _resolve_agent(db: AsyncSession, agent: "Agent | None", topic: TopicMo async def _resolve_provider_and_model( - db: AsyncSession, agent: "Agent | None", topic: TopicModel + db: AsyncSession, + agent: "Agent | None", + topic: TopicModel, + message_text: str | None = None, + user_provider_manager: Any = None, ) -> tuple[str | None, str | None]: """ Determine provider and model to use. 
- Priority: Session Override > Agent Default > System Default + Priority: Session Model > Session Tier (with intelligent selection) > Agent Default > System Default (STANDARD tier) + + When model_tier is set but no session.model: + - Uses intelligent selection to pick the best model for the task + - Caches the selection in session.model for subsequent messages """ from app.repos.session import SessionRepository + from app.schemas.model_tier import ModelTier, get_candidate_for_model, resolve_model_for_tier + + from .model_selector import select_model_for_tier session_repo = SessionRepository(db) session = await session_repo.get_session_by_id(topic.session_id) @@ -161,8 +178,41 @@ async def _resolve_provider_and_model( if session: if session.provider_id: provider_id = str(session.provider_id) + + # If session.model is already set, use it directly (cached selection) if session.model: model_name = session.model + logger.info(f"Using cached session model: {model_name}") + # If model_tier is set but no model, do intelligent selection + elif session.model_tier: + if message_text and user_provider_manager: + try: + model_name = await select_model_for_tier( + tier=session.model_tier, + first_message=message_text, + user_provider_manager=user_provider_manager, + ) + logger.info(f"Intelligent selection chose model: {model_name} for tier {session.model_tier.value}") + + # Cache the selection in session.model for subsequent messages. + # Note: Concurrent requests may race to set this, but that's fine since + # they would set the same (or equivalent) value - this is idempotent. + from app.models.sessions import SessionUpdate + + await session_repo.update_session( + session_id=session.id, + session_update=SessionUpdate(model=model_name), + ) + await db.commit() + logger.info(f"Cached selected model in session: {model_name}") + except Exception as e: + logger.error(f"Intelligent model selection failed: {e}") + model_name = resolve_model_for_tier(session.model_tier) + logger.warning(f"Falling back to tier default: {model_name}") + else: + # No message or provider manager, use simple fallback + model_name = resolve_model_for_tier(session.model_tier) + logger.info(f"Using tier fallback (no context): {model_name}") if not provider_id and agent and agent.provider_id: provider_id = str(agent.provider_id) @@ -170,6 +220,41 @@ async def _resolve_provider_and_model( if not model_name and agent and agent.model: model_name = agent.model + # Final fallback: if still no model, use STANDARD tier default + # This handles cases where session has no model/tier and agent has no default model + if not model_name: + default_tier = ModelTier.STANDARD + if message_text and user_provider_manager: + try: + model_name = await select_model_for_tier( + tier=default_tier, + first_message=message_text, + user_provider_manager=user_provider_manager, + ) + logger.info(f"Using STANDARD tier selection: {model_name}") + except Exception as e: + logger.error(f"STANDARD tier selection failed: {e}") + model_name = resolve_model_for_tier(default_tier) + logger.warning(f"Falling back to STANDARD tier default: {model_name}") + else: + model_name = resolve_model_for_tier(default_tier) + logger.info(f"Using STANDARD tier fallback: {model_name}") + + # Ensure we have the correct provider for the selected model. + # The model's required provider takes precedence over session.provider_id. 
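+    # Illustration (assuming the tier candidate tables pair Gemini models with the
+    # GOOGLE_VERTEX provider type): if the selected model is "gemini-2.5-flash", the
+    # loop below overrides any session-level provider_id with the matching system provider.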
+ if model_name and user_provider_manager: + candidate = get_candidate_for_model(model_name) + if candidate: + # Find a configured provider that matches the candidate's required type + for config in user_provider_manager.list_providers(): + if config.provider_type == candidate.provider_type: + provider_id = config.name + logger.info( + f"Resolved provider {provider_id} for model {model_name} (type: {candidate.provider_type})" + ) + break + + logger.info(f"Resolved model: {model_name}, provider: {provider_id}") return provider_id, model_name @@ -354,10 +439,18 @@ async def _handle_updates_mode(data: Any, ctx: StreamContext) -> AsyncGenerator[ # Structured output nodes (clarify_with_user, write_research_brief, etc.) # These nodes use with_structured_output and don't stream normally # They return clean content in messages, so we emit it as if streamed - elif hasattr(last_message, "content") and step_name in { - "clarify_with_user", - "write_research_brief", - }: + # Check for structured output metadata to support custom nodes + elif hasattr(last_message, "content") and ( + step_name + in { + "clarify_with_user", + "write_research_brief", + } + or ( + hasattr(last_message, "additional_kwargs") + and "structured_output" in last_message.additional_kwargs.get("node_metadata", {}) + ) + ): content = last_message.content if isinstance(content, str) and content: logger.debug("Structured output from '%s': %s", step_name, content[:100]) diff --git a/service/app/core/chat/model_selector.py b/service/app/core/chat/model_selector.py new file mode 100644 index 00000000..beb0e7e3 --- /dev/null +++ b/service/app/core/chat/model_selector.py @@ -0,0 +1,244 @@ +"""Intelligent model selection for tier-based model routing. + +This module uses a small LLM (Gemini 2.5 Flash) to analyze the user's first message +and select the optimal model from the available candidates in their chosen tier. +""" + +import logging +from typing import TYPE_CHECKING + +from langchain_core.messages import HumanMessage + +from app.schemas.model_tier import ( + MODEL_SELECTOR_MODEL, + MODEL_SELECTOR_PROVIDER, + TIER_MODEL_CANDIDATES, + ModelTier, + TierModelCandidate, + get_fallback_model_for_tier, +) +from app.schemas.provider import ProviderType + +if TYPE_CHECKING: + from app.core.providers.manager import ProviderManager + +logger = logging.getLogger(__name__) + + +# Editable prompt template for model selection +MODEL_SELECTION_PROMPT = """You are a model router. Select the best model for the user's task. + +Available models: +{available_models} + +User's task: +{user_message} + +Instructions: +- Pick the model that best matches the task requirements +- If the task needs image generation, pick an image-capable model +- You MUST select one model from the list above +- Return ONLY the exact model ID (e.g., "gemini-3-pro-preview"), nothing else""" + + +def _get_available_provider_types(user_provider_manager: "ProviderManager") -> set[ProviderType]: + """Get the set of available provider types from the provider manager. 
+ + Args: + user_provider_manager: The user's provider manager + + Returns: + Set of available ProviderType values + """ + providers = user_provider_manager.list_providers() + available_types = {cfg.provider_type for cfg in providers} + logger.info(f"Available provider types: {[t.value for t in available_types]}") + return available_types + + +def _filter_candidates_by_availability( + tier: ModelTier, + available_types: set[ProviderType], +) -> tuple[list[TierModelCandidate], TierModelCandidate | None]: + """Filter tier candidates to only those with available providers. + + Args: + tier: The model tier + available_types: Set of available provider types + + Returns: + Tuple of (available_candidates, fallback_candidate) + """ + candidates = TIER_MODEL_CANDIDATES.get(tier, TIER_MODEL_CANDIDATES[ModelTier.STANDARD]) + available_candidates: list[TierModelCandidate] = [] + fallback: TierModelCandidate | None = None + + for candidate in candidates: + logger.debug( + f"Checking candidate {candidate.model}: provider={candidate.provider_type.value}, " + f"in_available={candidate.provider_type in available_types}, is_fallback={candidate.is_fallback}" + ) + if candidate.provider_type in available_types: + if candidate.is_fallback: + fallback = candidate + else: + available_candidates.append(candidate) + + # Sort by priority (lower = higher priority) + available_candidates.sort(key=lambda c: c.priority) + + logger.info( + f"Tier {tier.value}: {len(available_candidates)} available candidates, " + f"fallback={'available' if fallback else 'not available'}" + ) + for c in available_candidates: + logger.info(f" - {c.model} (priority={c.priority}, provider={c.provider_type.value})") + + return available_candidates, fallback + + +def _format_available_models_for_prompt(candidates: list[TierModelCandidate]) -> str: + """Format available models for the LLM prompt. + + Args: + candidates: List of available model candidates + + Returns: + Formatted string for the prompt + """ + lines = [] + for c in candidates: + lines.append(f"- {c.model}: {c.description}") + return "\n".join(lines) + + +async def select_model_for_tier( + tier: ModelTier, + first_message: str, + user_provider_manager: "ProviderManager", +) -> str: + """ + Intelligently select a model for the given tier based on the first message. + + Uses Gemini 2.5 Flash to analyze the task and select the optimal model. + Falls back to the tier's fallback model if selection fails. 
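+
+    Usage sketch (illustrative; assumes `manager` is a ProviderManager obtained via
+    get_user_provider_manager):
+
+        model = await select_model_for_tier(
+            tier=ModelTier.STANDARD,
+            first_message="Summarize this paper on protein folding",
+            user_provider_manager=manager,
+        )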
+ + Args: + tier: The user-selected model tier + first_message: The user's first message in the session + user_provider_manager: The user's provider manager + + Returns: + The selected model name + """ + logger.info(f"Starting model selection for tier: {tier.value}") + logger.debug(f"First message: {first_message[:200]}...") + + # Get available providers + available_types = _get_available_provider_types(user_provider_manager) + + # Filter candidates by availability + available_candidates, fallback = _filter_candidates_by_availability(tier, available_types) + + # If no candidates available, use fallback + if not available_candidates: + if fallback: + logger.warning(f"No candidates available for tier {tier.value}, using fallback: {fallback.model}") + return fallback.model + else: + fallback_candidate = get_fallback_model_for_tier(tier) + logger.warning( + f"No candidates or fallback available for tier {tier.value}, " + f"using tier fallback: {fallback_candidate.model}" + ) + return fallback_candidate.model + + # If only one candidate, use it directly (skip LLM call) + if len(available_candidates) == 1: + selected = available_candidates[0].model + logger.info(f"Only one candidate available, selecting: {selected}") + return selected + + # Check if selector model provider is available + if MODEL_SELECTOR_PROVIDER not in available_types: + # Use highest priority candidate + selected = available_candidates[0].model + logger.warning( + f"Model selector provider ({MODEL_SELECTOR_PROVIDER.value}) not available, " + f"using highest priority: {selected}" + ) + return selected + + # Use LLM to select model + try: + selected = await _llm_select_model( + available_candidates, + first_message, + user_provider_manager, + ) + logger.info(f"LLM selected model: {selected}") + return selected + except Exception as e: + logger.error(f"LLM model selection failed: {e}") + # Fall back to highest priority candidate + selected = ( + available_candidates[0].model + if available_candidates + else fallback.model + if fallback + else get_fallback_model_for_tier(tier).model + ) + logger.warning(f"Falling back to: {selected}") + return selected + + +async def _llm_select_model( + candidates: list[TierModelCandidate], + first_message: str, + user_provider_manager: "ProviderManager", +) -> str: + """Use LLM to select the best model from candidates. 
+ + Args: + candidates: Available model candidates + first_message: User's first message + user_provider_manager: Provider manager for LLM access + + Returns: + Selected model name + + Raises: + Exception: If LLM call fails or returns invalid selection + """ + # Format prompt + available_models_str = _format_available_models_for_prompt(candidates) + logger.info(f"Available models: {available_models_str}") + prompt = MODEL_SELECTION_PROMPT.format( + available_models=available_models_str, + user_message=first_message[:2000], # Truncate long messages + ) + + logger.debug(f"Model selection prompt:\n{prompt}") + + # Create LLM and call + llm = await user_provider_manager.create_langchain_model( + provider_id=MODEL_SELECTOR_PROVIDER, + model=MODEL_SELECTOR_MODEL, + ) + + response = await llm.ainvoke([HumanMessage(content=prompt)]) + logger.debug(f"LLM response: {response}") + + # Parse response + if isinstance(response.content, str): + selected_model = response.content.strip() + + # Validate selection + valid_models = {c.model for c in candidates} + if selected_model in valid_models: + return selected_model + else: + logger.warning(f"LLM selected invalid model: {selected_model}, valid: {valid_models}") + raise ValueError(f"Invalid model selection: {selected_model}") + else: + raise ValueError(f"Unexpected response type: {type(response.content)}") diff --git a/service/app/core/chat/topic_generator.py b/service/app/core/chat/topic_generator.py index f0534add..1b7ff4b3 100644 --- a/service/app/core/chat/topic_generator.py +++ b/service/app/core/chat/topic_generator.py @@ -1,65 +1,45 @@ import json import logging -from typing import TYPE_CHECKING from uuid import UUID from langchain_core.messages import HumanMessage -from app.configs import configs from app.core.providers import get_user_provider_manager from app.infra.database import AsyncSessionLocal from app.models.topic import TopicUpdate -from app.repos.session import SessionRepository from app.repos.topic import TopicRepository -from app.schemas.provider import ProviderType - -if TYPE_CHECKING: - from app.api.ws.v1.chat import ConnectionManager +from app.schemas.model_tier import TOPIC_RENAME_MODEL, TOPIC_RENAME_PROVIDER logger = logging.getLogger(__name__) -def _select_title_generation_model( - *, - provider_type: ProviderType | None, - session_model: str | None, - default_model: str | None, -) -> str | None: - if provider_type == ProviderType.GOOGLE_VERTEX: - return "gemini-2.5-flash" - if provider_type == ProviderType.AZURE_OPENAI: - return "gpt-5-mini" - if provider_type == ProviderType.GPUGEEK: - return "Vendor2/Gemini-2.5-Flash" - if provider_type == ProviderType.QWEN: - return "qwen-flash" - return session_model or default_model - - async def generate_and_update_topic_title( message_text: str, topic_id: UUID, session_id: UUID, user_id: str, - connection_manager: "ConnectionManager", connection_id: str, ) -> None: """ Background task to generate a concise title for a topic based on its content and update it in the database. + This function publishes the update to Redis pub/sub instead of using a local + ConnectionManager, ensuring the update reaches the client regardless of which + pod handles the WebSocket connection. 
+ Args: + message_text: The user's message to generate a title from topic_id: The UUID of the topic to update + session_id: The session UUID (unused but kept for API compatibility) user_id: The user ID (for LLM access) - connection_manager: WebSocket connection manager to broadcast updates - connection_id: The specific WebSocket connection ID to send the update to + connection_id: The WebSocket connection ID (used as Redis channel key) """ logger.info(f"Starting background title generation for topic {topic_id}") async with AsyncSessionLocal() as db: try: topic_repo = TopicRepository(db) - session_repo = SessionRepository(db) topic = await topic_repo.get_topic_by_id(topic_id) if not topic: @@ -68,30 +48,6 @@ async def generate_and_update_topic_title( user_provider_manager = await get_user_provider_manager(user_id, db) - # Prefer the session-selected provider/model; otherwise fall back to system defaults. - session = await session_repo.get_session_by_id(session_id) - provider_id = str(session.provider_id) if session and session.provider_id else None - session_model = session.model if session and session.model else None - - default_cfg = configs.LLM.default_config - default_model = default_cfg.model if default_cfg else None - - provider_type: ProviderType | None = None - if provider_id: - cfg = user_provider_manager.get_provider_config(provider_id) - provider_type = cfg.provider_type if cfg else None - else: - provider_type = configs.LLM.default_provider - - model_name = _select_title_generation_model( - provider_type=provider_type, - session_model=session_model, - default_model=default_model, - ) - - if not model_name: - logger.error("No model configured for title generation") - return prompt = ( "Please generate a short, concise title (3-5 words) based on the following user query. " "Do not use quotes. 
" @@ -99,7 +55,10 @@ async def generate_and_update_topic_title( f"{message_text}" ) - llm = await user_provider_manager.create_langchain_model(provider_id, model_name) + llm = await user_provider_manager.create_langchain_model( + provider_id=TOPIC_RENAME_PROVIDER, + model=TOPIC_RENAME_MODEL, + ) response = await llm.ainvoke([HumanMessage(content=prompt)]) logger.debug(f"LLM response: {response}") @@ -121,7 +80,14 @@ async def generate_and_update_topic_title( "updated_at": updated_topic.updated_at.isoformat(), }, } - await connection_manager.send_personal_message(json.dumps(event), connection_id) + # Publish to Redis channel for cross-pod delivery + # The redis_listener in chat.py subscribes to this channel + from app.infra.redis import get_redis_client + + r = await get_redis_client() + channel = f"chat:{connection_id}" + await r.publish(channel, json.dumps(event)) + logger.debug(f"Published topic_updated event to Redis channel: {channel}") except Exception as e: logger.error(f"Error in title generation task: {e}") diff --git a/service/app/core/consume.py b/service/app/core/consume.py index c708dc9b..dd33ea0c 100644 --- a/service/app/core/consume.py +++ b/service/app/core/consume.py @@ -49,6 +49,9 @@ async def create_consume_record( input_tokens: int | None = None, output_tokens: int | None = None, total_tokens: int | None = None, + model_tier: str | None = None, + tier_rate: float | None = None, + calculation_breakdown: str | None = None, ) -> ConsumeRecord: """ Create consumption record and execute remote billing (if needed) @@ -67,6 +70,9 @@ async def create_consume_record( input_tokens: Number of input tokens used output_tokens: Number of output tokens generated total_tokens: Total tokens (input + output) + model_tier: Model tier used (ultra/pro/standard/lite) + tier_rate: Tier rate multiplier applied + calculation_breakdown: JSON breakdown of calculation Returns: Created consumption record @@ -109,6 +115,9 @@ async def create_consume_record( input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens, + model_tier=model_tier, + tier_rate=tier_rate, + calculation_breakdown=calculation_breakdown, consume_state=initial_state, ) @@ -327,6 +336,9 @@ async def create_consume_for_chat( input_tokens: int | None = None, output_tokens: int | None = None, total_tokens: int | None = None, + model_tier: str | None = None, + tier_rate: float | None = None, + calculation_breakdown: str | None = None, ) -> ConsumeRecord: """ Convenience function to create consumption record for chat @@ -344,6 +356,9 @@ async def create_consume_for_chat( input_tokens: Number of input tokens used output_tokens: Number of output tokens generated total_tokens: Total tokens (input + output) + model_tier: Model tier used (ultra/pro/standard/lite) + tier_rate: Tier rate multiplier applied + calculation_breakdown: JSON breakdown of calculation Returns: Consumption record @@ -361,4 +376,7 @@ async def create_consume_for_chat( input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens, + model_tier=model_tier, + tier_rate=tier_rate, + calculation_breakdown=calculation_breakdown, ) diff --git a/service/app/core/consume_calculator.py b/service/app/core/consume_calculator.py new file mode 100644 index 00000000..6881423d --- /dev/null +++ b/service/app/core/consume_calculator.py @@ -0,0 +1,59 @@ +"""Consumption calculator with strategy pattern. + +This module provides a factory/manager for consumption calculation strategies, +allowing easy switching between different pricing approaches. 
+""" + +from app.core.consume_strategy import ( + ConsumptionContext, + ConsumptionResult, + ConsumptionStrategy, + TierBasedConsumptionStrategy, +) + + +class ConsumptionCalculator: + """Factory and executor for consumption strategies. + + This class manages strategy instances and provides a unified interface + for calculating consumption amounts. + + Usage: + context = ConsumptionContext( + model_tier=ModelTier.PRO, + input_tokens=1000, + output_tokens=500, + ) + result = ConsumptionCalculator.calculate(context) + print(f"Amount: {result.amount}, Breakdown: {result.breakdown}") + """ + + _strategy: ConsumptionStrategy = TierBasedConsumptionStrategy() + + @classmethod + def calculate(cls, context: ConsumptionContext) -> ConsumptionResult: + """Calculate consumption amount using tier-based strategy. + + Args: + context: ConsumptionContext with all relevant information + + Returns: + ConsumptionResult with amount and breakdown + """ + return cls._strategy.calculate(context) + + @classmethod + def set_strategy(cls, strategy: ConsumptionStrategy) -> None: + """Set the consumption strategy. + + Useful for testing or runtime extension. + + Args: + strategy: Strategy instance to use + """ + cls._strategy = strategy + + @classmethod + def reset_strategy(cls) -> None: + """Reset to default tier-based strategy.""" + cls._strategy = TierBasedConsumptionStrategy() diff --git a/service/app/core/consume_strategy.py b/service/app/core/consume_strategy.py new file mode 100644 index 00000000..7d3abc99 --- /dev/null +++ b/service/app/core/consume_strategy.py @@ -0,0 +1,118 @@ +"""Consumption calculation strategies. + +This module defines the strategy pattern for consumption calculation, +allowing extensible and configurable pricing strategies. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +from app.schemas.model_tier import TIER_MODEL_CONSUMPTION_RATE, ModelTier + + +@dataclass +class ConsumptionContext: + """Context for consumption calculation. + + This dataclass holds all information needed to calculate consumption. + Extensible: add more fields as pricing needs evolve. + """ + + model_tier: ModelTier | None = None + input_tokens: int = 0 + output_tokens: int = 0 + total_tokens: int = 0 + content_length: int = 0 + generated_files_count: int = 0 + + +@dataclass +class ConsumptionResult: + """Result of consumption calculation. + + Attributes: + amount: Final consumption amount (integer points) + breakdown: Detailed breakdown of calculation for transparency/debugging + """ + + amount: int + breakdown: dict[str, Any] = field(default_factory=dict) + + +class ConsumptionStrategy(ABC): + """Abstract base for consumption calculation strategies. + + Implement this interface to create new pricing strategies. + """ + + @abstractmethod + def calculate(self, context: ConsumptionContext) -> ConsumptionResult: + """Calculate consumption amount based on context. + + Args: + context: ConsumptionContext with all relevant information + + Returns: + ConsumptionResult with amount and breakdown + """ + pass + + +class TierBasedConsumptionStrategy(ConsumptionStrategy): + """Calculate consumption using tier multipliers. 
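+
+    Worked example (illustrative; assumes the PRO tier maps to a 3.0 rate in
+    TIER_MODEL_CONSUMPTION_RATE, matching the rates described in this patch):
+        1000 input tokens  -> 1000 * 0.2 / 1000 = 0.2
+        500 output tokens  ->  500 * 1.0 / 1000 = 0.5
+        base cost 3, no generated files
+        pre-multiplier total = 3.7; final amount = int(3.7 * 3.0) = 11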
+ + Design decisions: + - LITE tier (rate 0.0) = completely free + - Tier rate multiplies ALL costs (base + tokens + files) + """ + + BASE_COST = 3 + INPUT_TOKEN_RATE = 0.2 / 1000 # per token + OUTPUT_TOKEN_RATE = 1 / 1000 # per token + FILE_GENERATION_COST = 10 + + def calculate(self, context: ConsumptionContext) -> ConsumptionResult: + """Calculate consumption with tier-based multiplier. + + Args: + context: ConsumptionContext with tier and usage information + + Returns: + ConsumptionResult with tier-adjusted amount + """ + tier_rate = TIER_MODEL_CONSUMPTION_RATE.get(context.model_tier, 1.0) if context.model_tier else 1.0 + + # LITE tier (rate 0.0) = completely free + if tier_rate == 0.0: + return ConsumptionResult( + amount=0, + breakdown={ + "base_cost": 0, + "token_cost": 0, + "file_cost": 0, + "tier_rate": 0.0, + "tier": context.model_tier.value if context.model_tier else "lite", + "note": "LITE tier - free usage", + }, + ) + + # Calculate base token cost + token_cost = context.input_tokens * self.INPUT_TOKEN_RATE + context.output_tokens * self.OUTPUT_TOKEN_RATE + file_cost = context.generated_files_count * self.FILE_GENERATION_COST + + # Tier rate multiplies ALL costs + base_amount = self.BASE_COST + token_cost + file_cost + final_amount = int(base_amount * tier_rate) + + return ConsumptionResult( + amount=final_amount, + breakdown={ + "base_cost": self.BASE_COST, + "token_cost": token_cost, + "file_cost": file_cost, + "pre_multiplier_total": base_amount, + "tier_rate": tier_rate, + "tier": context.model_tier.value if context.model_tier else "default", + }, + ) diff --git a/service/app/core/model_registry/service.py b/service/app/core/model_registry/service.py index 2a3bb40a..1b02d209 100644 --- a/service/app/core/model_registry/service.py +++ b/service/app/core/model_registry/service.py @@ -6,10 +6,12 @@ This service replaces LiteLLM-based model information with models.dev data. 
""" +import json import logging import re import time from dataclasses import dataclass, field +from typing import Any import httpx @@ -116,36 +118,70 @@ class ModelsDevService: Features: - Async HTTP fetching from https://models.dev/api.json - - In-memory caching with configurable TTL + - Redis caching for multi-pod deployments (with in-memory fallback) - Model lookup by ID and provider - Conversion to LiteLLM-compatible ModelInfo format """ API_URL = "https://models.dev/api.json" CACHE_TTL = 3600 # 1 hour in seconds + CACHE_KEY = "models:dev:api" - _cache: ModelsDevResponse | None = None - _cache_time: float = 0 + # In-memory fallback cache (used when Redis is unavailable or disabled) + _local_cache: ModelsDevResponse | None = None + _local_cache_time: float = 0 @classmethod - def _is_cache_valid(cls) -> bool: - """Check if the cache is still valid based on TTL.""" - if cls._cache is None: + async def _get_redis(cls) -> Any | None: + """Get Redis client if available.""" + try: + from app.configs import configs + + if configs.Redis.CacheBackend != "redis": + return None + + from app.infra.redis import get_redis_client + + return await get_redis_client() + except Exception as e: + logger.debug(f"Redis not available for models cache: {e}") + return None + + @classmethod + def _is_local_cache_valid(cls) -> bool: + """Check if the local cache is still valid based on TTL.""" + if cls._local_cache is None: return False - return (time.time() - cls._cache_time) < cls.CACHE_TTL + return (time.time() - cls._local_cache_time) < cls.CACHE_TTL @classmethod async def fetch_data(cls) -> ModelsDevResponse: """ Fetch model data from models.dev API with caching. + Uses Redis cache for multi-pod deployments, falls back to local cache. + Returns: Dictionary mapping provider ID to ModelsDevProvider """ - if cls._is_cache_valid() and cls._cache is not None: - logger.debug("Using cached models.dev data") - return cls._cache + # Try Redis cache first + redis_client = await cls._get_redis() + if redis_client: + try: + cached_data = await redis_client.get(cls.CACHE_KEY) + if cached_data: + logger.debug("Using Redis cached models.dev data") + raw_data = json.loads(cached_data) + return cls._parse_raw_data(raw_data) + except Exception as e: + logger.warning(f"Redis cache read failed: {e}") + + # Try local cache + if cls._is_local_cache_valid() and cls._local_cache is not None: + logger.debug("Using local cached models.dev data") + return cls._local_cache + # Fetch fresh data logger.info("Fetching fresh data from models.dev API") try: async with httpx.AsyncClient(timeout=30.0) as client: @@ -153,29 +189,43 @@ async def fetch_data(cls) -> ModelsDevResponse: response.raise_for_status() raw_data = response.json() - # Parse the response into typed models - parsed: ModelsDevResponse = {} - for provider_id, provider_data in raw_data.items(): + parsed = cls._parse_raw_data(raw_data) + + # Update Redis cache + if redis_client: try: - parsed[provider_id] = ModelsDevProvider.model_validate(provider_data) + await redis_client.setex(cls.CACHE_KEY, cls.CACHE_TTL, json.dumps(raw_data)) + logger.info(f"Cached {len(parsed)} providers in Redis") except Exception as e: - logger.warning(f"Failed to parse provider {provider_id}: {e}") - continue + logger.warning(f"Redis cache write failed: {e}") - # Update cache - cls._cache = parsed - cls._cache_time = time.time() + # Update local cache as fallback + cls._local_cache = parsed + cls._local_cache_time = time.time() logger.info(f"Cached {len(parsed)} providers from models.dev") + return 
parsed except httpx.HTTPError as e: logger.error(f"Failed to fetch from models.dev API: {e}") - # Return cached data if available, even if expired - if cls._cache is not None: - logger.warning("Returning stale cached data due to fetch error") - return cls._cache + # Return any cached data if available, even if expired + if cls._local_cache is not None: + logger.warning("Returning stale local cached data due to fetch error") + return cls._local_cache raise + @classmethod + def _parse_raw_data(cls, raw_data: dict) -> ModelsDevResponse: + """Parse raw API response into typed models.""" + parsed: ModelsDevResponse = {} + for provider_id, provider_data in raw_data.items(): + try: + parsed[provider_id] = ModelsDevProvider.model_validate(provider_data) + except Exception as e: + logger.warning(f"Failed to parse provider {provider_id}: {e}") + continue + return parsed + @classmethod async def get_model_info( cls, @@ -626,8 +676,18 @@ async def list_all_models(cls) -> list[str]: return model_ids @classmethod - def clear_cache(cls) -> None: - """Clear the in-memory cache.""" - cls._cache = None - cls._cache_time = 0 - logger.info("models.dev cache cleared") + async def clear_cache(cls) -> None: + """Clear both Redis and local caches.""" + # Clear Redis cache + redis_client = await cls._get_redis() + if redis_client: + try: + await redis_client.delete(cls.CACHE_KEY) + logger.info("Redis models.dev cache cleared") + except Exception as e: + logger.warning(f"Failed to clear Redis cache: {e}") + + # Clear local cache + cls._local_cache = None + cls._local_cache_time = 0 + logger.info("Local models.dev cache cleared") diff --git a/service/app/core/providers/manager.py b/service/app/core/providers/manager.py index 8a5a7891..6419f248 100644 --- a/service/app/core/providers/manager.py +++ b/service/app/core/providers/manager.py @@ -71,13 +71,15 @@ def remove_provider(self, name: str) -> None: logger.info(f"Removed provider '{name}'") async def create_langchain_model( - self, provider_id: str | None = None, model: str | None = None, **override_kwargs: Any + self, provider_id: str | ProviderType | None = None, model: str | None = None, **override_kwargs: Any ) -> BaseChatModel: """ Create a LangChain model using the stored config and the ChatModelFactory. Args: - provider_id: The provider ID (UUID string). If None, uses system provider as fallback. + provider_id: The provider ID (UUID string, system alias, or ProviderType enum). + If ProviderType is passed, it will be converted to system: format. + If None, uses system provider as fallback. model: The model name to use. Required. override_kwargs: Runtime overrides (e.g. temperature, max_tokens) @@ -91,6 +93,10 @@ async def create_langchain_model( if not model: raise ErrCode.MODEL_NOT_SPECIFIED.with_messages("Model must be specified") + # Convert ProviderType enum to system alias format + if isinstance(provider_id, ProviderType): + provider_id = f"{SYSTEM_PROVIDER_NAME}:{provider_id.value}" + async def infer_provider_preference(model_name: str) -> list[ProviderType]: """Infer likely provider type(s) for a model. @@ -164,22 +170,24 @@ async def infer_provider_preference(model_name: str) -> list[ProviderType]: return model_instance.llm -user_provider_managers: dict[str, ProviderManager] = {} - - async def get_user_provider_manager(user_id: str, db: AsyncSession) -> ProviderManager: """ - Create a provider manager with all providers for a specific user. 
- """ - if user_id in user_provider_managers: - return user_provider_managers[user_id] + Create a provider manager with system providers. + Note: User-defined providers are disabled. This function now only loads + system providers configured via environment variables. + + This function rebuilds the ProviderManager from the database on each call + to support stateless multi-pod deployments. Since all users share the same + system providers, the overhead is minimal (~5ms with connection pooling). + """ from app.repos.provider import ProviderRepository provider_repo = ProviderRepository(db) - all_providers = await provider_repo.get_providers_by_user(user_id, include_system=True) + # Only load system providers - user-defined providers are disabled + all_providers = await provider_repo.get_all_system_providers() if not all_providers: - raise ErrCode.PROVIDER_NOT_FOUND.with_messages("No providers found for user") + raise ErrCode.PROVIDER_NOT_FOUND.with_messages("No system providers configured") from app.configs import configs @@ -243,6 +251,4 @@ async def get_user_provider_manager(user_id: str, db: AsyncSession) -> ProviderM logger.error(f"Failed to load provider {db_provider.name} for user {user_id}: {e}") continue - user_provider_managers[user_id] = user_provider_manager - return user_provider_manager diff --git a/service/app/core/websocket.py b/service/app/core/websocket.py index 282733b5..514815cc 100644 --- a/service/app/core/websocket.py +++ b/service/app/core/websocket.py @@ -2,13 +2,20 @@ WebSocket connection management for broadcasting updates. This module provides a centralized way to manage WebSocket connections and broadcast messages to connected clients. + +For multi-pod deployments, broadcasts are sent via Redis pub/sub to ensure +all connected clients receive updates regardless of which pod they're connected to. """ +import asyncio import json +import logging from typing import Any from fastapi import WebSocket +logger = logging.getLogger(__name__) + class ConnectionManager: """Manages active WebSocket connections.""" @@ -45,5 +52,101 @@ async def broadcast(self, data: Any) -> None: self.disconnect(connection) +class RedisBroadcastManager(ConnectionManager): + """ + WebSocket connection manager with Redis pub/sub support for multi-pod deployments. 
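A rough sketch, not from the diff, of how the now-stateless manager might be obtained per request and combined with the ProviderType shortcut added to create_langchain_model; the import path, model name, and temperature override are illustrative.

from app.core.providers.manager import get_user_provider_manager  # assumed import path
from app.schemas.provider import ProviderType


async def answer(user_id: str, prompt: str, db) -> str:
    # Rebuilt from the database on every call; only system providers are loaded.
    manager = await get_user_provider_manager(user_id, db)

    # A ProviderType value is converted internally to the "system:<type>" alias.
    llm = await manager.create_langchain_model(
        provider_id=ProviderType.QWEN,
        model="qwen3-max",   # placeholder model name
        temperature=0.2,     # runtime override passed through override_kwargs
    )
    reply = await llm.ainvoke(prompt)
    return str(reply.content)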
+ + - Local connections are managed in-memory (same as ConnectionManager) + - Broadcasts are published to Redis and received by all pods + - Each pod subscribes to the Redis channel and forwards to local connections + """ + + def __init__(self, channel: str) -> None: + super().__init__() + self.channel = channel + self._subscriber_task: asyncio.Task | None = None + self._redis: Any = None + + async def _get_redis(self) -> Any: + """Get async Redis client.""" + if self._redis is None: + from app.infra.redis import get_redis_client + + self._redis = await get_redis_client() + return self._redis + + async def start_subscriber(self) -> None: + """Start the Redis subscriber task to receive broadcasts from other pods.""" + if self._subscriber_task is not None: + return + + async def _subscriber_loop() -> None: + import redis.asyncio as redis + + from app.configs import configs + + r = redis.from_url(configs.Redis.REDIS_URL, decode_responses=True) + pubsub = r.pubsub() + await pubsub.subscribe(self.channel) + logger.info(f"MCP WebSocket manager subscribed to Redis channel: {self.channel}") + + try: + async for message in pubsub.listen(): + if message["type"] == "message": + # Forward the message to all local connections + data = message["data"] + await self._broadcast_local(data) + except asyncio.CancelledError: + logger.info(f"Redis subscriber cancelled for channel: {self.channel}") + except Exception as e: + logger.error(f"Redis subscriber error: {e}") + finally: + await pubsub.unsubscribe(self.channel) + await r.close() + + self._subscriber_task = asyncio.create_task(_subscriber_loop()) + + async def stop_subscriber(self) -> None: + """Stop the Redis subscriber task.""" + if self._subscriber_task is not None: + self._subscriber_task.cancel() + try: + await self._subscriber_task + except asyncio.CancelledError: + pass + self._subscriber_task = None + + async def _broadcast_local(self, message: str) -> None: + """Broadcast a message to local connections only.""" + if not self.active_connections: + return + + connections = self.active_connections.copy() + for connection in connections: + try: + await connection.send_text(message) + except Exception: + self.disconnect(connection) + + async def broadcast(self, data: Any) -> None: + """ + Broadcast a message to all connected clients across all pods. + + Publishes to Redis, which will be received by all pods + (including this one via the subscriber). + """ + message = json.dumps(data, default=str) + + try: + redis_client = await self._get_redis() + await redis_client.publish(self.channel, message) + logger.debug(f"Published MCP status update to Redis channel: {self.channel}") + except Exception as e: + logger.error(f"Failed to publish to Redis, falling back to local broadcast: {e}") + # Fallback to local broadcast if Redis fails + await self._broadcast_local(message) + + # Global instance for MCP server status broadcasts -mcp_websocket_manager = ConnectionManager() +# Use Redis-backed manager for multi-pod support +mcp_websocket_manager = RedisBroadcastManager(channel="mcp:status") diff --git a/service/app/infra/redis/__init__.py b/service/app/infra/redis/__init__.py new file mode 100644 index 00000000..b69473be --- /dev/null +++ b/service/app/infra/redis/__init__.py @@ -0,0 +1,92 @@ +""" +Redis client infrastructure for distributed caching. + +This module provides an async Redis client singleton that can be used +throughout the application for caching and pub/sub operations. 
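One possible wiring, assumed rather than taken from the patch: starting the Redis subscriber during the FastAPI lifespan so cross-pod broadcasts reach local WebSocket clients. The payload shape in notify_status is made up for the example.

from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.websocket import mcp_websocket_manager  # assumed import path
from app.infra.redis import close_redis_client, health_check


@asynccontextmanager
async def lifespan(app: FastAPI):
    if await health_check():
        # Begin forwarding messages from the "mcp:status" channel to local clients.
        await mcp_websocket_manager.start_subscriber()
    yield
    await mcp_websocket_manager.stop_subscriber()
    await close_redis_client()


app = FastAPI(lifespan=lifespan)


async def notify_status(server_id: str, status: str) -> None:
    # Published to Redis, so clients on every pod receive the update.
    await mcp_websocket_manager.broadcast({"server_id": server_id, "status": status})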
+""" + +import logging +from collections.abc import AsyncGenerator +from enum import Enum + +import redis.asyncio as redis + +from app.configs import configs + +logger = logging.getLogger(__name__) + + +class CacheBackend(str, Enum): + """Cache backend options.""" + + LOCAL = "local" + REDIS = "redis" + + +# Global Redis client instance +_redis_client: redis.Redis | None = None + + +async def get_redis_client() -> redis.Redis: + """ + Get the global async Redis client instance. + + Creates a new connection on first call, reuses existing connection + on subsequent calls. + + Returns: + redis.Redis: Async Redis client instance + """ + global _redis_client + if _redis_client is None: + _redis_client = redis.from_url( + configs.Redis.REDIS_URL, + decode_responses=True, + ) + logger.info(f"Redis client initialized: {configs.Redis.HOST}:{configs.Redis.PORT}") + return _redis_client + + +async def get_redis_dependency() -> AsyncGenerator[redis.Redis, None]: + """ + FastAPI dependency for Redis client. + + Yields: + redis.Redis: Async Redis client instance + """ + client = await get_redis_client() + yield client + + +async def close_redis_client() -> None: + """Close the global Redis client connection.""" + global _redis_client + if _redis_client is not None: + await _redis_client.close() + _redis_client = None + logger.info("Redis client connection closed") + + +async def health_check() -> bool: + """ + Check Redis connectivity. + + Returns: + bool: True if Redis is reachable, False otherwise + """ + try: + client = await get_redis_client() + await client.ping() + return True + except Exception as e: + logger.error(f"Redis health check failed: {e}") + return False + + +__all__ = [ + "CacheBackend", + "get_redis_client", + "get_redis_dependency", + "close_redis_client", + "health_check", +] diff --git a/service/app/middleware/auth/cache.py b/service/app/middleware/auth/cache.py index cb646650..c9307a37 100644 --- a/service/app/middleware/auth/cache.py +++ b/service/app/middleware/auth/cache.py @@ -1,18 +1,72 @@ """ Token 缓存服务,减少重复的认证服务商调用 + +Supports both local in-memory cache (single pod) and Redis cache (multi-pod). """ import asyncio +import hashlib +import json import logging -from dataclasses import dataclass, field +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field from datetime import datetime, timedelta from typing import Any -from . import AuthResult +from . 
import AuthResult, UserInfo logger = logging.getLogger(__name__) +def _auth_result_to_json(result: AuthResult) -> str: + """Serialize AuthResult to JSON string.""" + data = asdict(result) + return json.dumps(data) + + +def _json_to_auth_result(json_str: str) -> AuthResult: + """Deserialize AuthResult from JSON string.""" + data = json.loads(json_str) + user_info = None + if data.get("user_info"): + user_info = UserInfo(**data["user_info"]) + return AuthResult( + success=data["success"], + user_info=user_info, + error_message=data.get("error_message"), + error_code=data.get("error_code"), + ) + + +class BaseTokenCache(ABC): + """Abstract base class for token cache implementations.""" + + @abstractmethod + async def get(self, token: str, provider: str) -> AuthResult | None: + """Get cached authentication result.""" + pass + + @abstractmethod + async def set(self, token: str, provider: str, auth_result: AuthResult, ttl_minutes: int | None = None) -> None: + """Set cached authentication result.""" + pass + + @abstractmethod + async def invalidate(self, token: str, provider: str) -> None: + """Invalidate cached authentication result.""" + pass + + @abstractmethod + async def clear(self) -> None: + """Clear all cached entries.""" + pass + + @abstractmethod + def get_stats(self) -> dict[str, Any]: + """Get cache statistics.""" + pass + + @dataclass class CachedAuthResult: """缓存的认证结果""" @@ -29,8 +83,8 @@ def is_expired(self) -> bool: return datetime.now() >= self.cached_at + timedelta(minutes=5) -class TokenCache: - """Token 缓存管理器""" +class TokenCache(BaseTokenCache): + """Token 缓存管理器 (In-memory implementation for single pod)""" def __init__(self, default_ttl_minutes: int = 5, max_size: int = 1000): self.default_ttl_minutes = default_ttl_minutes @@ -43,8 +97,6 @@ def __init__(self, default_ttl_minutes: int = 5, max_size: int = 1000): def _get_cache_key(self, token: str, provider: str) -> str: """生成缓存键,使用token的hash而不是完整token来节省内存""" - import hashlib - token_hash = hashlib.sha256(token.encode()).hexdigest()[:16] return f"{provider}:{token_hash}" @@ -134,16 +186,123 @@ def get_stats(self) -> dict[str, Any]: "cache_size": len(self._cache), "max_size": self.max_size, "default_ttl_minutes": self.default_ttl_minutes, + "backend": "local", + } + + +class RedisTokenCache(BaseTokenCache): + """Token cache backed by Redis for multi-pod deployments.""" + + CACHE_PREFIX = "auth:token:" + + def __init__(self, default_ttl_minutes: int = 5): + self.default_ttl_minutes = default_ttl_minutes + self._redis: Any = None # Lazy initialization + + async def _get_redis(self) -> Any: + """Get Redis client lazily.""" + if self._redis is None: + from app.infra.redis import get_redis_client + + self._redis = await get_redis_client() + return self._redis + + def _get_cache_key(self, token: str, provider: str) -> str: + """Generate cache key using hashed token.""" + token_hash = hashlib.sha256(token.encode()).hexdigest()[:16] + return f"{self.CACHE_PREFIX}{provider}:{token_hash}" + + async def get(self, token: str, provider: str) -> AuthResult | None: + """Get cached authentication result from Redis.""" + cache_key = self._get_cache_key(token, provider) + + try: + redis = await self._get_redis() + data = await redis.get(cache_key) + + if not data: + return None + + logger.debug(f"Redis token cache hit for key: {cache_key}") + return _json_to_auth_result(data) + + except Exception as e: + logger.error(f"Redis token cache get error: {e}") + return None + + async def set(self, token: str, provider: str, auth_result: AuthResult, 
ttl_minutes: int | None = None) -> None: + """Set cached authentication result in Redis.""" + if not auth_result.success: + return + + cache_key = self._get_cache_key(token, provider) + ttl_seconds = (ttl_minutes or self.default_ttl_minutes) * 60 + + try: + redis = await self._get_redis() + json_data = _auth_result_to_json(auth_result) + await redis.setex(cache_key, ttl_seconds, json_data) + logger.debug(f"Redis token cached for key: {cache_key}, TTL: {ttl_seconds}s") + + except Exception as e: + logger.error(f"Redis token cache set error: {e}") + + async def invalidate(self, token: str, provider: str) -> None: + """Invalidate cached authentication result in Redis.""" + cache_key = self._get_cache_key(token, provider) + + try: + redis = await self._get_redis() + await redis.delete(cache_key) + logger.debug(f"Redis token cache invalidated for key: {cache_key}") + + except Exception as e: + logger.error(f"Redis token cache invalidate error: {e}") + + async def clear(self) -> None: + """Clear all token cache entries in Redis.""" + try: + redis = await self._get_redis() + # Use SCAN to find all keys with our prefix + cursor = 0 + while True: + cursor, keys = await redis.scan(cursor, match=f"{self.CACHE_PREFIX}*", count=100) + if keys: + await redis.delete(*keys) + if cursor == 0: + break + logger.debug("Redis token cache cleared") + + except Exception as e: + logger.error(f"Redis token cache clear error: {e}") + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics.""" + return { + "default_ttl_minutes": self.default_ttl_minutes, + "backend": "redis", } # 全局缓存实例 -_token_cache: TokenCache | None = None +_token_cache: BaseTokenCache | None = None + +def get_token_cache() -> BaseTokenCache: + """获取全局token缓存实例 -def get_token_cache() -> TokenCache: - """获取全局token缓存实例""" + Returns the appropriate cache backend based on configuration: + - "local": In-memory cache (default, for single pod) + - "redis": Redis-backed cache (for multi-pod deployments) + """ global _token_cache if _token_cache is None: - _token_cache = TokenCache() + from app.configs import configs + + if configs.Redis.CacheBackend == "redis": + logger.info("Using Redis token cache backend") + _token_cache = RedisTokenCache() + else: + logger.info("Using local in-memory token cache backend") + _token_cache = TokenCache() return _token_cache diff --git a/service/app/middleware/auth/simple_cache.py b/service/app/middleware/auth/simple_cache.py index 6b4e8bca..74405d24 100644 --- a/service/app/middleware/auth/simple_cache.py +++ b/service/app/middleware/auth/simple_cache.py @@ -1,19 +1,71 @@ """ 简化的token缓存装饰器,用于优化认证性能 + +Supports both local in-memory cache (single pod) and Redis cache (multi-pod). +This is a synchronous cache used by sync auth provider methods. """ +import hashlib +import json import logging import time +from abc import ABC, abstractmethod from functools import wraps from typing import Any, Callable -from . import AuthResult +from . 
import AuthResult, UserInfo logger = logging.getLogger(__name__) -class SimpleTokenCache: - """简单的内存token缓存""" +def _auth_result_to_json(result: AuthResult) -> str: + """Serialize AuthResult to JSON string.""" + from dataclasses import asdict + + data = asdict(result) + return json.dumps(data) + + +def _json_to_auth_result(json_str: str) -> AuthResult: + """Deserialize AuthResult from JSON string.""" + data = json.loads(json_str) + user_info = None + if data.get("user_info"): + user_info = UserInfo(**data["user_info"]) + return AuthResult( + success=data["success"], + user_info=user_info, + error_message=data.get("error_message"), + error_code=data.get("error_code"), + ) + + +class BaseSimpleTokenCache(ABC): + """Abstract base class for simple token cache implementations.""" + + @abstractmethod + def get(self, token: str, provider: str) -> AuthResult | None: + """Get cached authentication result.""" + pass + + @abstractmethod + def set(self, token: str, provider: str, auth_result: AuthResult) -> None: + """Set cached authentication result.""" + pass + + @abstractmethod + def clear(self) -> None: + """Clear all cached entries.""" + pass + + @abstractmethod + def get_stats(self) -> dict[str, Any]: + """Get cache statistics.""" + pass + + +class SimpleTokenCache(BaseSimpleTokenCache): + """简单的内存token缓存 (for single pod deployments)""" def __init__(self, ttl_seconds: int = 300): # 默认5分钟TTL self.ttl_seconds = ttl_seconds @@ -22,8 +74,6 @@ def __init__(self, ttl_seconds: int = 300): # 默认5分钟TTL def _get_cache_key(self, token: str, provider: str) -> str: """生成缓存键""" - import hashlib - token_hash = hashlib.sha256(token.encode()).hexdigest()[:16] return f"{provider}:{token_hash}" @@ -87,18 +137,123 @@ def clear(self) -> None: def get_stats(self) -> dict[str, Any]: """获取缓存统计""" - return {"cache_size": len(self._cache), "max_size": self._max_size, "ttl_seconds": self.ttl_seconds} + return { + "cache_size": len(self._cache), + "max_size": self._max_size, + "ttl_seconds": self.ttl_seconds, + "backend": "local", + } + + +class RedisSimpleTokenCache(BaseSimpleTokenCache): + """Redis-backed simple token cache for multi-pod deployments. + + Uses synchronous Redis operations since this cache is used in sync contexts. 
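A short sketch, outside the patch, of how the async cache factory might be used; the provider name and the import path for AuthResult are assumptions, and the constructor call simply mirrors the fields handled by the serializer above.

from app.middleware.auth import AuthResult  # assumed import path
from app.middleware.auth.cache import get_token_cache


async def authenticate(token: str) -> AuthResult:
    # TokenCache or RedisTokenCache, depending on configs.Redis.CacheBackend.
    cache = get_token_cache()

    cached = await cache.get(token, "casdoor")  # "casdoor" is an illustrative provider name
    if cached is not None:
        return cached

    # ... call the real auth provider here ...
    result = AuthResult(success=True, user_info=None, error_message=None, error_code=None)

    # The Redis backend stores successful results only; set() returns early otherwise.
    await cache.set(token, "casdoor", result, ttl_minutes=5)
    return result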
+ """ + + CACHE_PREFIX = "auth:simple_token:" + + def __init__(self, ttl_seconds: int = 300): + self.ttl_seconds = ttl_seconds + self._redis: Any = None + + def _get_redis(self) -> Any: + """Get synchronous Redis client.""" + if self._redis is None: + import redis + + from app.configs import configs + + self._redis = redis.from_url( + configs.Redis.REDIS_URL, + decode_responses=True, + ) + return self._redis + + def _get_cache_key(self, token: str, provider: str) -> str: + """Generate cache key using hashed token.""" + token_hash = hashlib.sha256(token.encode()).hexdigest()[:16] + return f"{self.CACHE_PREFIX}{provider}:{token_hash}" + + def get(self, token: str, provider: str) -> AuthResult | None: + """Get cached authentication result from Redis.""" + cache_key = self._get_cache_key(token, provider) + + try: + redis_client = self._get_redis() + data = redis_client.get(cache_key) + + if not data: + return None + + logger.debug(f"Redis simple token cache hit for key: {cache_key}") + return _json_to_auth_result(data) + + except Exception as e: + logger.error(f"Redis simple token cache get error: {e}") + return None + + def set(self, token: str, provider: str, auth_result: AuthResult) -> None: + """Set cached authentication result in Redis.""" + if not auth_result.success: + return + + cache_key = self._get_cache_key(token, provider) + + try: + redis_client = self._get_redis() + json_data = _auth_result_to_json(auth_result) + redis_client.setex(cache_key, self.ttl_seconds, json_data) + logger.debug(f"Redis simple token cached for key: {cache_key}, TTL: {self.ttl_seconds}s") + + except Exception as e: + logger.error(f"Redis simple token cache set error: {e}") + + def clear(self) -> None: + """Clear all simple token cache entries in Redis.""" + try: + redis_client = self._get_redis() + cursor = 0 + while True: + cursor, keys = redis_client.scan(cursor, match=f"{self.CACHE_PREFIX}*", count=100) + if keys: + redis_client.delete(*keys) + if cursor == 0: + break + logger.debug("Redis simple token cache cleared") + + except Exception as e: + logger.error(f"Redis simple token cache clear error: {e}") + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics.""" + return { + "ttl_seconds": self.ttl_seconds, + "backend": "redis", + } # 全局缓存实例 -_simple_token_cache: SimpleTokenCache | None = None +_simple_token_cache: BaseSimpleTokenCache | None = None + +def get_simple_token_cache() -> BaseSimpleTokenCache: + """获取全局简单token缓存实例 -def get_simple_token_cache() -> SimpleTokenCache: - """获取全局简单token缓存实例""" + Returns the appropriate cache backend based on configuration: + - "local": In-memory cache (default, for single pod) + - "redis": Redis-backed cache (for multi-pod deployments) + """ global _simple_token_cache if _simple_token_cache is None: - _simple_token_cache = SimpleTokenCache() + from app.configs import configs + + if configs.Redis.CacheBackend == "redis": + logger.info("Using Redis simple token cache backend") + _simple_token_cache = RedisSimpleTokenCache() + else: + logger.info("Using local in-memory simple token cache backend") + _simple_token_cache = SimpleTokenCache() return _simple_token_cache diff --git a/service/app/models/consume.py b/service/app/models/consume.py index 29b6c2dc..35e050b3 100644 --- a/service/app/models/consume.py +++ b/service/app/models/consume.py @@ -25,6 +25,11 @@ class ConsumeRecordBase(SQLModel): output_tokens: int | None = Field(default=None, description="Number of output tokens generated") total_tokens: int | None = Field(default=None, description="Total 
tokens (input + output)") + # Tier-based pricing + model_tier: str | None = Field(default=None, description="Model tier used (ultra/pro/standard/lite)") + tier_rate: float | None = Field(default=None, description="Tier rate multiplier applied") + calculation_breakdown: str | None = Field(default=None, description="JSON breakdown of calculation") + # Billing status consume_state: str = Field(default="pending", description="Consumption state: pending/success/failed") remote_error: str | None = Field(default=None, description="Remote billing error information") @@ -81,6 +86,9 @@ class ConsumeRecordUpdate(SQLModel): input_tokens: int | None = Field(default=None, description="Number of input tokens used") output_tokens: int | None = Field(default=None, description="Number of output tokens generated") total_tokens: int | None = Field(default=None, description="Total tokens (input + output)") + model_tier: str | None = Field(default=None, description="Model tier used (ultra/pro/standard/lite)") + tier_rate: float | None = Field(default=None, description="Tier rate multiplier applied") + calculation_breakdown: str | None = Field(default=None, description="JSON breakdown of calculation") consume_state: str | None = Field(default=None, description="Consumption state: pending/success/failed") remote_error: str | None = Field(default=None, description="Remote billing error information") remote_response: str | None = Field(default=None, description="Remote billing response") diff --git a/service/app/models/sessions.py b/service/app/models/sessions.py index 67cad25b..7636a228 100644 --- a/service/app/models/sessions.py +++ b/service/app/models/sessions.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from .topic import TopicRead +from app.schemas.model_tier import ModelTier + def builtin_agent_id_to_uuid(agent_id: str) -> UUID: """ @@ -73,6 +75,7 @@ class SessionBase(SQLModel): user_id: str = Field(index=True) provider_id: UUID | None = Field(default=None, description="If set, overrides the agent's provider") model: str | None = Field(default=None, description="If set, overrides the agent's model") + model_tier: ModelTier | None = Field(default=None, description="User-selected model tier for simplified selection") google_search_enabled: bool = Field( default=False, description="Enable built-in web search for supported models (e.g., Gemini)" ) @@ -97,6 +100,7 @@ class SessionCreate(SQLModel): agent_id: str | UUID | None = Field(default=None) provider_id: UUID | None = None model: str | None = None + model_tier: ModelTier | None = None google_search_enabled: bool = False @@ -119,4 +123,5 @@ class SessionUpdate(SQLModel): is_active: bool | None = None provider_id: UUID | None = None model: str | None = None + model_tier: ModelTier | None = None google_search_enabled: bool | None = None diff --git a/service/app/repos/provider.py b/service/app/repos/provider.py index a50679ac..b30dd5a1 100644 --- a/service/app/repos/provider.py +++ b/service/app/repos/provider.py @@ -67,6 +67,20 @@ async def get_system_provider(self) -> Provider | None: logger.debug(f"Found system provider: {provider.name}") return provider + async def get_all_system_providers(self) -> list[Provider]: + """ + Fetches all system providers. + + Returns: + List of all system Provider instances. 
+ """ + logger.debug("Fetching all system providers") + statement = select(Provider).where(Provider.scope == ProviderScope.SYSTEM) + result = await self.db.exec(statement) + providers = list(result.all()) + logger.debug(f"Found {len(providers)} system providers") + return providers + async def get_system_provider_by_type(self, provider_type: ProviderType) -> Provider | None: """Fetch a system provider for a specific ProviderType.""" logger.debug(f"Fetching system provider for type: {provider_type}") diff --git a/service/app/repos/session.py b/service/app/repos/session.py index 33c26fb3..a9b08e0e 100644 --- a/service/app/repos/session.py +++ b/service/app/repos/session.py @@ -102,14 +102,17 @@ async def create_session(self, session_data: SessionCreate, user_id: str) -> Ses await self.db.refresh(session) return session - async def update_session(self, session_id: UUID, session_data: SessionUpdate) -> SessionModel | None: + async def update_session(self, session_id: UUID, session_update: SessionUpdate) -> SessionModel | None: """ Updates an existing session. This function does NOT commit the transaction. + When model_tier is changed, clears session.model to trigger re-selection + on the next message. + Args: session_id: The UUID of the session to update. - session_data: The Pydantic model containing the update data. + session_update: The Pydantic model containing the update data. Returns: The updated SessionModel instance, or None if not found. @@ -119,9 +122,22 @@ async def update_session(self, session_id: UUID, session_data: SessionUpdate) -> if not session: return None + # Check if model_tier is being changed + update_data = session_update.model_dump(exclude_unset=True) + if "model_tier" in update_data: + new_tier = update_data.get("model_tier") + if new_tier != session.model_tier: + # Clear session.model to trigger re-selection on next message + logger.info( + f"Session {session_id}: model_tier changed from {session.model_tier} to {new_tier}, " + f"clearing model to trigger re-selection" + ) + session.model = None + # Only update fields that are not None to avoid null constraint violations - update_data = session_data.model_dump(exclude_unset=True, exclude_none=True) - for field, value in update_data.items(): + # But we already handled model clearing above for tier changes + update_data_filtered = session_update.model_dump(exclude_unset=True, exclude_none=True) + for field, value in update_data_filtered.items(): if hasattr(session, field): setattr(session, field, value) diff --git a/service/app/schemas/model_tier.py b/service/app/schemas/model_tier.py new file mode 100644 index 00000000..6c7a9931 --- /dev/null +++ b/service/app/schemas/model_tier.py @@ -0,0 +1,218 @@ +"""Model tier definitions for intelligent model selection. + +Users select from 4 tiers instead of specific models. The backend +uses an LLM to intelligently select the best model for the task +from available candidates in the tier. +""" + +from dataclasses import dataclass, field +from enum import Enum + +from app.schemas.provider import ProviderType + + +class ModelTier(str, Enum): + """User-facing model tiers for simplified selection.""" + + ULTRA = "ultra" # Complex reasoning, research tasks + PRO = "pro" # Production workloads + STANDARD = "standard" # General purpose, balanced + LITE = "lite" # Quick responses, simple tasks + + +@dataclass +class TierModelCandidate: + """ + A model candidate for a tier. 
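A sketch, not in the patch, of the behaviour described above: switching a session's tier clears the stored model so a new one is picked on the next message. Obtaining the AsyncSession is left out.

from uuid import UUID

from app.models.sessions import SessionUpdate  # assumed import path
from app.repos.session import SessionRepository
from app.schemas.model_tier import ModelTier


async def switch_tier(db, session_id: UUID) -> None:
    repo = SessionRepository(db)

    # If the tier actually changes, update_session also sets session.model = None,
    # which triggers model re-selection for the new tier on the next message.
    session = await repo.update_session(session_id, SessionUpdate(model_tier=ModelTier.PRO))
    if session is None:
        raise ValueError("session not found")

    await db.commit()  # update_session itself does not commit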
+ + Extensible design for future selection strategies: + - weight: For probability-based selection (higher = more likely) + - priority: For priority-based selection (lower = higher priority) + - capabilities: For capability-based matching + - is_fallback: Always-available fallback option + """ + + model: str + provider_type: ProviderType + is_fallback: bool = False + weight: float = 1.0 # For probability-based selection + priority: int = 0 # Lower = higher priority (for ordered selection) + capabilities: list[str] = field(default_factory=list) # e.g., ["coding", "creative", "reasoning"] + description: str = "" # Human-readable description for LLM selection + + +# Model for intelligent selection (Gemini 2.5 Flash) +MODEL_SELECTOR_MODEL = "qwen3-next-80b-a3b-instruct" +MODEL_SELECTOR_PROVIDER = ProviderType.QWEN + +# Model for topic title generation (fast, efficient model) +TOPIC_RENAME_MODEL = "qwen3-next-80b-a3b-instruct" +TOPIC_RENAME_PROVIDER = ProviderType.QWEN + + +# Tier-to-model candidates mapping +# Each tier has multiple candidates with a Gemini fallback +TIER_MODEL_CANDIDATES: dict[ModelTier, list[TierModelCandidate]] = { + ModelTier.ULTRA: [ + TierModelCandidate( + model="Vendor2/Claude-4.5-Opus", + provider_type=ProviderType.GPUGEEK, + priority=1, + capabilities=["reasoning", "creative", "coding"], + description="Best for coding and choose this for most tasks if no need to generate images", + ), + TierModelCandidate( + model="gemini-3-pro-image-preview", + provider_type=ProviderType.GOOGLE_VERTEX, + priority=2, + description="Must select this if user wants to generate images", + ), + TierModelCandidate( + model="gpt-5.2-pro", + provider_type=ProviderType.AZURE_OPENAI, + is_fallback=True, + priority=99, + capabilities=["coding", "analysis"], + description="Only use this for really complex reasoning, never use this for normal tasks", + ), + ], + ModelTier.PRO: [ + TierModelCandidate( + model="gemini-3-pro-preview", + provider_type=ProviderType.GOOGLE_VERTEX, + priority=1, + description="Choose this for most tasks if no need to generate images", + ), + TierModelCandidate( + model="gemini-2.5-flash-image", + provider_type=ProviderType.GOOGLE_VERTEX, + priority=2, + description="Must select this if user wants to generate images", + ), + TierModelCandidate( + model="gpt-5.2", + provider_type=ProviderType.AZURE_OPENAI, + priority=3, + capabilities=["coding", "analysis"], + description="Choose this model if user faces a task requiring complex reasoning", + ), + TierModelCandidate( + model="Vendor2/Claude-4.5-Sonnet", + provider_type=ProviderType.GPUGEEK, + is_fallback=True, + priority=99, + capabilities=["reasoning", "creative", "coding"], + description="Choose this model if user wants to code and write documents", + ), + ], + ModelTier.STANDARD: [ + TierModelCandidate( + model="gemini-3-flash-preview", + provider_type=ProviderType.GOOGLE_VERTEX, + priority=1, + capabilities=["general", "fast"], + description="Choose this for most tasks", + ), + TierModelCandidate( + model="qwen3-max", + provider_type=ProviderType.QWEN, + priority=2, + capabilities=["coding", "multilingual"], + description="Choose this if user uses Chinese", + ), + TierModelCandidate( + model="gpt-5-mini", + provider_type=ProviderType.AZURE_OPENAI, + is_fallback=True, + priority=99, + capabilities=["general"], + description="Choose this if this is just a simple task", + ), + ], + ModelTier.LITE: [ + TierModelCandidate( + model="DeepSeek/DeepSeek-V3.1-0821", + provider_type=ProviderType.GPUGEEK, + priority=4, + 
capabilities=["coding", "efficient"], + description="Choose this for most tasks", + ), + TierModelCandidate( + model="qwen3-30b-a3b", + provider_type=ProviderType.QWEN, + priority=1, + capabilities=["fast", "efficient"], + description="Choose this if user needs quick responses", + ), + TierModelCandidate( + model="gemini-2.5-flash-lite", + provider_type=ProviderType.GOOGLE_VERTEX, + is_fallback=True, + priority=99, + capabilities=["fast", "efficient"], + description="Choose this if this is just a simple task", + ), + TierModelCandidate( + model="gpt-5-nano", + provider_type=ProviderType.AZURE_OPENAI, + priority=3, + capabilities=["fast", "efficient"], + description="Choose this if user needs reasoning", + ), + ], +} + +TIER_MODEL_CONSUMPTION_RATE: dict[ModelTier, float] = { + ModelTier.ULTRA: 6.8, + ModelTier.PRO: 3.0, + ModelTier.STANDARD: 1.0, + ModelTier.LITE: 0.0, +} + + +def get_fallback_model_for_tier(tier: ModelTier) -> TierModelCandidate: + """Get the fallback model for a tier. + + Args: + tier: The model tier + + Returns: + The fallback TierModelCandidate for the tier + """ + candidates = TIER_MODEL_CANDIDATES.get(tier, TIER_MODEL_CANDIDATES[ModelTier.STANDARD]) + for candidate in candidates: + if candidate.is_fallback: + return candidate + # If no fallback defined, use last candidate + return candidates[-1] + + +def resolve_model_for_tier(tier: ModelTier) -> str: + """Resolve the default (fallback) model name for a given tier. + + This is a simple fallback that returns the tier's fallback model. + For intelligent selection, use the model_selector service. + + Args: + tier: The user-selected model tier + + Returns: + The fallback model name for this tier + """ + return get_fallback_model_for_tier(tier).model + + +def get_candidate_for_model(model_name: str) -> TierModelCandidate | None: + """Get the candidate definition for a specific model name. 
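Not part of the diff: a minimal illustration of the fallback helpers and the rate table defined above.

from app.schemas.model_tier import (
    TIER_MODEL_CONSUMPTION_RATE,
    ModelTier,
    get_fallback_model_for_tier,
    resolve_model_for_tier,
)

for tier in ModelTier:
    fallback = get_fallback_model_for_tier(tier)            # the candidate marked is_fallback=True
    assert resolve_model_for_tier(tier) == fallback.model   # resolve_* returns the fallback's name
    rate = TIER_MODEL_CONSUMPTION_RATE[tier]                # e.g. 0.0 for LITE, 6.8 for ULTRA
    print(f"{tier.value}: fallback={fallback.model}, rate={rate}")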
+ + Args: + model_name: The model name to look up + + Returns: + The TierModelCandidate if found, else None + """ + for candidates in TIER_MODEL_CANDIDATES.values(): + for candidate in candidates: + if candidate.model == model_name: + return candidate + return None diff --git a/service/app/tasks/chat.py b/service/app/tasks/chat.py index 02094fb6..c44a93e1 100644 --- a/service/app/tasks/chat.py +++ b/service/app/tasks/chat.py @@ -13,10 +13,13 @@ from app.core.celery_app import celery_app from app.core.chat import get_ai_response_stream from app.core.consume import create_consume_for_chat +from app.core.consume_calculator import ConsumptionCalculator +from app.core.consume_strategy import ConsumptionContext from app.infra.database import ASYNC_DATABASE_URL from app.models.citation import CitationCreate from app.models.message import Message, MessageCreate from app.repos import CitationRepository, FileRepository, MessageRepository, TopicRepository +from app.repos.session import SessionRepository from app.schemas.chat_event_payloads import CitationData from app.schemas.chat_event_types import ChatEventType @@ -387,15 +390,22 @@ async def _process_chat_message_async( # Settlement try: - IMAGE_GENERATION_COST = 10 - generated_files_cost = generated_files_count * IMAGE_GENERATION_COST - - if total_tokens > 0: - token_cost = (input_tokens * 1 + output_tokens * 3) // 1000 - total_cost = 3 + token_cost + generated_files_count # base cost hardcoded to 3 as before - else: - base_cost = 3 # Hardcoded fallback - total_cost = int(base_cost + len(full_content) // 100 + generated_files_cost) + # Get session to retrieve model_tier + session_repo = SessionRepository(db) + session = await session_repo.get_session_by_id(session_id) + model_tier = session.model_tier if session else None + + # Use strategy pattern for consumption calculation + consume_context = ConsumptionContext( + model_tier=model_tier, + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + content_length=len(full_content), + generated_files_count=generated_files_count, + ) + result = ConsumptionCalculator.calculate(consume_context) + total_cost = result.amount remaining_amount = total_cost - pre_deducted_amount @@ -413,6 +423,9 @@ async def _process_chat_message_async( input_tokens=input_tokens if total_tokens > 0 else None, output_tokens=output_tokens if total_tokens > 0 else None, total_tokens=total_tokens if total_tokens > 0 else None, + model_tier=model_tier.value if model_tier else None, + tier_rate=result.breakdown.get("tier_rate"), + calculation_breakdown=json.dumps(result.breakdown), ) except ErrCodeError as e: if e.code == ErrCode.INSUFFICIENT_BALANCE: diff --git a/service/migrations/versions/ab49b572f009_add_tier_to_session.py b/service/migrations/versions/ab49b572f009_add_tier_to_session.py new file mode 100644 index 00000000..6d76086d --- /dev/null +++ b/service/migrations/versions/ab49b572f009_add_tier_to_session.py @@ -0,0 +1,39 @@ +"""add_tier_to_session + +Revision ID: ab49b572f009 +Revises: f0d7b93430e1 +Create Date: 2026-01-12 20:55:07.656593 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
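Purely illustrative, mirroring the settlement hunk in tasks/chat.py above: the strategy-based calculator replaces the old inline cost formula. The token and file counts here are made up.

import json

from app.core.consume_calculator import ConsumptionCalculator
from app.core.consume_strategy import ConsumptionContext
from app.schemas.model_tier import ModelTier

context = ConsumptionContext(
    model_tier=ModelTier.PRO,   # taken from the session; None falls back to rate 1.0
    input_tokens=1200,
    output_tokens=800,
    total_tokens=2000,
    content_length=4500,
    generated_files_count=1,    # image generations add a per-file cost
)

result = ConsumptionCalculator.calculate(context)
total_cost = result.amount                        # integer amount to settle
tier_rate = result.breakdown.get("tier_rate")     # persisted in the tier_rate column
breakdown_json = json.dumps(result.breakdown)     # persisted in calculation_breakdown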
+revision: str = "ab49b572f009" +down_revision: Union[str, Sequence[str], None] = "f0d7b93430e1" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + # Create enum type first (PostgreSQL requires explicit creation) + modeltier_enum = sa.Enum("ULTRA", "PRO", "STANDARD", "LITE", name="modeltier") + modeltier_enum.create(op.get_bind(), checkfirst=True) + + op.add_column("session", sa.Column("model_tier", modeltier_enum, nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("session", "model_tier") + # Drop the enum type (PostgreSQL requires explicit cleanup) + op.execute("DROP TYPE IF EXISTS modeltier") + # ### end Alembic commands ### diff --git a/service/migrations/versions/c712246c7034_add_tier_pricing_to_consume.py b/service/migrations/versions/c712246c7034_add_tier_pricing_to_consume.py new file mode 100644 index 00000000..3c4bf848 --- /dev/null +++ b/service/migrations/versions/c712246c7034_add_tier_pricing_to_consume.py @@ -0,0 +1,40 @@ +"""add_tier_pricing_to_consume + +Revision ID: c712246c7034 +Revises: ab49b572f009 +Create Date: 2026-01-14 20:45:35.088242 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = "c712246c7034" +down_revision: Union[str, Sequence[str], None] = "ab49b572f009" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("consumerecord", sa.Column("model_tier", sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + op.add_column("consumerecord", sa.Column("tier_rate", sa.Float(), nullable=True)) + op.add_column( + "consumerecord", sa.Column("calculation_breakdown", sqlmodel.sql.sqltypes.AutoString(), nullable=True) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("consumerecord", "calculation_breakdown") + op.drop_column("consumerecord", "tier_rate") + op.drop_column("consumerecord", "model_tier") + # ### end Alembic commands ### diff --git a/service/tests/integration/test_handlers/test_provider_api.py b/service/tests/integration/test_handlers/test_provider_api.py index 5140a76b..3f4d827f 100644 --- a/service/tests/integration/test_handlers/test_provider_api.py +++ b/service/tests/integration/test_handlers/test_provider_api.py @@ -9,8 +9,8 @@ class TestProviderAPI: """Integration tests for Provider API endpoints.""" - async def test_create_provider_endpoint(self, async_client: AsyncClient): - """Test POST /api/v1/providers/""" + async def test_create_provider_endpoint_forbidden(self, async_client: AsyncClient): + """Test POST /api/v1/providers/ returns 403 (user providers disabled)""" payload = ProviderCreateFactory.build( scope=ProviderScope.USER, provider_type=ProviderType.OPENAI, @@ -20,61 +20,108 @@ async def test_create_provider_endpoint(self, async_client: AsyncClient): ).model_dump(mode="json") response = await async_client.post("/xyzen/api/v1/providers/", json=payload) - assert response.status_code == 201 + assert response.status_code == 403 data = response.json() - assert data["name"] == "Test API Provider" - assert data["id"] is not None - assert data["key"] == "sk-test-key" + assert "disabled" in data["detail"].lower() async def test_get_my_providers(self, async_client: AsyncClient): - """Test GET /api/v1/providers/me""" - # Create a provider first - payload = ProviderCreateFactory.build( - scope=ProviderScope.USER, provider_type=ProviderType.GOOGLE, name="My Google Provider" - ).model_dump(mode="json") - c_response = await async_client.post("/xyzen/api/v1/providers/", json=payload) - assert c_response.status_code == 201 - - # List providers + """Test GET /api/v1/providers/me returns only system providers""" + # List providers (should only return system providers) response = await async_client.get("/xyzen/api/v1/providers/me") assert response.status_code == 200 data = response.json() assert isinstance(data, list) - # Should find our new provider - found = any(p["name"] == "My Google Provider" for p in data) - assert found + # All returned providers should be system providers + for provider in data: + assert provider["is_system"] is True + + async def test_get_system_providers(self, async_client: AsyncClient): + """Test GET /api/v1/providers/system returns system providers or 404 if none configured""" + response = await async_client.get("/xyzen/api/v1/providers/system") + # 200 if system providers exist, 404 if none configured + assert response.status_code in (200, 404) + + if response.status_code == 200: + data = response.json() + assert isinstance(data, list) + # All returned providers should be system providers + for provider in data: + assert provider["is_system"] is True + # Sensitive data should be masked + assert provider["key"] == "••••••••" + assert provider["api"] == "•••••••••••••••••" async def test_get_provider_detail(self, async_client: AsyncClient): - """Test GET /api/v1/providers/{id}""" - # Create - payload = ProviderCreateFactory.build(scope=ProviderScope.USER, provider_type=ProviderType.OPENAI).model_dump( - mode="json" - ) + """Test GET /api/v1/providers/{id} for system provider""" + # Get list of system providers first + list_response = await async_client.get("/xyzen/api/v1/providers/system") + if list_response.status_code == 404: + pytest.skip("No system providers configured for testing") - 
c_response = await async_client.post("/xyzen/api/v1/providers/", json=payload) - provider_id = c_response.json()["id"] + assert list_response.status_code == 200 + providers = list_response.json() - # Get + if not providers: + pytest.skip("No system providers configured for testing") + + provider_id = providers[0]["id"] + + # Get specific provider response = await async_client.get(f"/xyzen/api/v1/providers/{provider_id}") assert response.status_code == 200 data = response.json() assert data["id"] == provider_id - assert data["key"] == payload["key"] + assert data["is_system"] is True + # Sensitive data should be masked for system providers + assert data["key"] == "••••••••" + + async def test_delete_provider_forbidden(self, async_client: AsyncClient): + """Test DELETE /api/v1/providers/{id} for system provider is forbidden""" + # Get list of system providers first + list_response = await async_client.get("/xyzen/api/v1/providers/system") + if list_response.status_code == 404: + pytest.skip("No system providers configured for testing") - async def test_delete_provider_endpoint(self, async_client: AsyncClient): - """Test DELETE /api/v1/providers/{id}""" - # Create - payload = ProviderCreateFactory.build(scope=ProviderScope.USER, provider_type=ProviderType.OPENAI).model_dump( - mode="json" - ) + assert list_response.status_code == 200 + providers = list_response.json() - c_response = await async_client.post("/xyzen/api/v1/providers/", json=payload) - provider_id = c_response.json()["id"] + if not providers: + pytest.skip("No system providers configured for testing") - # Delete + provider_id = providers[0]["id"] + + # Try to delete system provider (should fail) response = await async_client.delete(f"/xyzen/api/v1/providers/{provider_id}") - assert response.status_code == 204 + # Should return 403 (Forbidden) since system providers can't be deleted + assert response.status_code == 403 - # Verify deletion - get_response = await async_client.get(f"/xyzen/api/v1/providers/{provider_id}") - assert get_response.status_code == 404 + async def test_get_available_models(self, async_client: AsyncClient): + """Test GET /api/v1/providers/available-models returns models for system providers""" + response = await async_client.get("/xyzen/api/v1/providers/available-models") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, dict) + # Should return a dict mapping provider IDs to model lists + for provider_id, models in data.items(): + assert isinstance(models, list) + + async def test_get_provider_models(self, async_client: AsyncClient): + """Test GET /api/v1/providers/{provider_id}/models""" + # Get list of system providers first + list_response = await async_client.get("/xyzen/api/v1/providers/system") + if list_response.status_code == 404: + pytest.skip("No system providers configured for testing") + + assert list_response.status_code == 200 + providers = list_response.json() + + if not providers: + pytest.skip("No system providers configured for testing") + + provider_id = providers[0]["id"] + + # Get models for specific provider + response = await async_client.get(f"/xyzen/api/v1/providers/{provider_id}/models") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) diff --git a/service/tests/unit/core/chat/test_topic_generator_selector.py b/service/tests/unit/core/chat/test_topic_generator_selector.py deleted file mode 100644 index 65ff333c..00000000 --- a/service/tests/unit/core/chat/test_topic_generator_selector.py +++ 
/dev/null @@ -1,35 +0,0 @@ -"""Unit tests for topic title auto-rename model selection.""" - -from __future__ import annotations - -import pytest - -from app.core.chat.topic_generator import _select_title_generation_model -from app.schemas.provider import ProviderType - - -class TestTopicGeneratorModelSelection: - @pytest.mark.parametrize( - ("provider_type", "session_model", "default_model", "expected"), - [ - (ProviderType.GOOGLE_VERTEX, "gemini-3-pro-image-preview", "gemini-3-pro", "gemini-2.5-flash"), - (ProviderType.AZURE_OPENAI, "gpt-5.2", "gpt-4.1", "gpt-5-mini"), - (None, "gpt-4.1", "gpt-5-mini", "gpt-4.1"), - (None, None, "gpt-5-mini", "gpt-5-mini"), - ], - ) - def test_select_title_generation_model( - self, - provider_type: ProviderType | None, - session_model: str | None, - default_model: str | None, - expected: str, - ) -> None: - assert ( - _select_title_generation_model( - provider_type=provider_type, - session_model=session_model, - default_model=default_model, - ) - == expected - ) diff --git a/service/tests/unit/test_core/test_consume_strategy.py b/service/tests/unit/test_core/test_consume_strategy.py new file mode 100644 index 00000000..2bcda4ee --- /dev/null +++ b/service/tests/unit/test_core/test_consume_strategy.py @@ -0,0 +1,239 @@ +"""Unit tests for consumption calculation strategies.""" + +import json + +from app.core.consume_calculator import ConsumptionCalculator +from app.core.consume_strategy import ( + ConsumptionContext, + TierBasedConsumptionStrategy, +) +from app.schemas.model_tier import TIER_MODEL_CONSUMPTION_RATE, ModelTier + + +class TestConsumptionContext: + """Tests for ConsumptionContext dataclass.""" + + def test_default_values(self) -> None: + """Test that ConsumptionContext has expected defaults.""" + context = ConsumptionContext() + assert context.model_tier is None + assert context.input_tokens == 0 + assert context.output_tokens == 0 + assert context.total_tokens == 0 + assert context.content_length == 0 + assert context.generated_files_count == 0 + + def test_with_values(self) -> None: + """Test ConsumptionContext with custom values.""" + context = ConsumptionContext( + model_tier=ModelTier.PRO, + input_tokens=1000, + output_tokens=500, + total_tokens=1500, + content_length=5000, + generated_files_count=2, + ) + assert context.model_tier == ModelTier.PRO + assert context.input_tokens == 1000 + assert context.output_tokens == 500 + assert context.total_tokens == 1500 + assert context.content_length == 5000 + assert context.generated_files_count == 2 + + +class TestTierBasedConsumptionStrategy: + """Tests for TierBasedConsumptionStrategy.""" + + def test_lite_tier_is_free(self) -> None: + """Test that LITE tier results in zero cost.""" + strategy = TierBasedConsumptionStrategy() + context = ConsumptionContext( + model_tier=ModelTier.LITE, + input_tokens=10000, + output_tokens=5000, + total_tokens=15000, + content_length=50000, + generated_files_count=5, + ) + result = strategy.calculate(context) + + assert result.amount == 0 + assert result.breakdown["tier_rate"] == 0.0 + assert result.breakdown["tier"] == "lite" + assert "note" in result.breakdown + + def test_standard_tier_base_multiplier(self) -> None: + """Test STANDARD tier with rate 1.0.""" + strategy = TierBasedConsumptionStrategy() + context = ConsumptionContext( + model_tier=ModelTier.STANDARD, + input_tokens=1000, + output_tokens=1000, + total_tokens=2000, + content_length=1000, + generated_files_count=0, + ) + result = strategy.calculate(context) + + # STANDARD rate is 1.0 + assert 
TIER_MODEL_CONSUMPTION_RATE[ModelTier.STANDARD] == 1.0 + + # Calculate expected: base(3) + tokens(1000*0.2/1000 + 1000*1/1000) = 3 + 0.2 + 1 = 4.2 + # With multiplier 1.0 = int(4.2) = 4 + expected_token_cost = (1000 * 0.2 / 1000) + (1000 * 1 / 1000) # 0.2 + 1 = 1.2 + expected = int((3 + expected_token_cost) * 1.0) + assert result.amount == expected + assert result.breakdown["tier_rate"] == 1.0 + + def test_pro_tier_multiplier(self) -> None: + """Test PRO tier with rate 3.0.""" + strategy = TierBasedConsumptionStrategy() + context = ConsumptionContext( + model_tier=ModelTier.PRO, + input_tokens=1000, + output_tokens=1000, + total_tokens=2000, + content_length=1000, + generated_files_count=0, + ) + result = strategy.calculate(context) + + # PRO rate is 3.0 + assert TIER_MODEL_CONSUMPTION_RATE[ModelTier.PRO] == 3.0 + + expected_token_cost = (1000 * 0.2 / 1000) + (1000 * 1 / 1000) # 1.2 + expected = int((3 + expected_token_cost) * 3.0) # 4.2 * 3 = 12.6 -> 12 + assert result.amount == expected + assert result.breakdown["tier_rate"] == 3.0 + + def test_ultra_tier_multiplier(self) -> None: + """Test ULTRA tier with rate 6.8.""" + strategy = TierBasedConsumptionStrategy() + context = ConsumptionContext( + model_tier=ModelTier.ULTRA, + input_tokens=1000, + output_tokens=1000, + total_tokens=2000, + content_length=1000, + generated_files_count=0, + ) + result = strategy.calculate(context) + + # ULTRA rate is 6.8 + assert TIER_MODEL_CONSUMPTION_RATE[ModelTier.ULTRA] == 6.8 + + expected_token_cost = (1000 * 0.2 / 1000) + (1000 * 1 / 1000) # 1.2 + expected = int((3 + expected_token_cost) * 6.8) # 4.2 * 6.8 = 28.56 -> 28 + assert result.amount == expected + assert result.breakdown["tier_rate"] == 6.8 + + def test_file_generation_cost(self) -> None: + """Test that file generation cost is included.""" + strategy = TierBasedConsumptionStrategy() + context = ConsumptionContext( + model_tier=ModelTier.STANDARD, + input_tokens=0, + output_tokens=0, + total_tokens=0, + content_length=0, + generated_files_count=2, + ) + result = strategy.calculate(context) + + # Base(3) + files(2*10) = 23, with rate 1.0 = 23 + expected = int((3 + 20) * 1.0) + assert result.amount == expected + assert result.breakdown["file_cost"] == 20 + + def test_no_tier_defaults_to_1(self) -> None: + """Test that None tier defaults to rate 1.0.""" + strategy = TierBasedConsumptionStrategy() + context = ConsumptionContext( + model_tier=None, + input_tokens=1000, + output_tokens=1000, + total_tokens=2000, + content_length=1000, + generated_files_count=0, + ) + result = strategy.calculate(context) + + # Should use default rate 1.0 + expected_token_cost = (1000 * 0.2 / 1000) + (1000 * 1 / 1000) # 1.2 + expected = int((3 + expected_token_cost) * 1.0) # 4 + assert result.amount == expected + assert result.breakdown["tier_rate"] == 1.0 + assert result.breakdown["tier"] == "default" + + def test_breakdown_contains_all_fields(self) -> None: + """Test that breakdown contains all expected fields.""" + strategy = TierBasedConsumptionStrategy() + context = ConsumptionContext( + model_tier=ModelTier.PRO, + input_tokens=1000, + output_tokens=500, + total_tokens=1500, + content_length=1000, + generated_files_count=1, + ) + result = strategy.calculate(context) + + assert "base_cost" in result.breakdown + assert "token_cost" in result.breakdown + assert "file_cost" in result.breakdown + assert "tier_rate" in result.breakdown + assert "tier" in result.breakdown + assert "pre_multiplier_total" in result.breakdown + + +class TestConsumptionCalculator: + """Tests 
for ConsumptionCalculator.""" + + def test_calculate_lite_tier_is_free(self) -> None: + """Test that LITE tier results in zero cost via calculator.""" + context = ConsumptionContext( + model_tier=ModelTier.LITE, + input_tokens=1000, + output_tokens=500, + total_tokens=1500, + ) + result = ConsumptionCalculator.calculate(context) + + assert result.amount == 0 + + def test_calculate_pro_tier(self) -> None: + """Test PRO tier calculation via calculator.""" + context = ConsumptionContext( + model_tier=ModelTier.PRO, + input_tokens=1000, + output_tokens=1000, + total_tokens=2000, + generated_files_count=0, + ) + result = ConsumptionCalculator.calculate(context) + + # PRO rate is 3.0 + expected_token_cost = (1000 * 0.2 / 1000) + (1000 * 1 / 1000) # 1.2 + expected = int((3 + expected_token_cost) * 3.0) # 12 + assert result.amount == expected + assert result.breakdown["tier_rate"] == 3.0 + + def test_breakdown_is_json_serializable(self) -> None: + """Test that breakdown can be serialized to JSON.""" + context = ConsumptionContext( + model_tier=ModelTier.PRO, + input_tokens=1000, + output_tokens=500, + total_tokens=1500, + generated_files_count=1, + ) + result = ConsumptionCalculator.calculate(context) + + # Should not raise + json_str = json.dumps(result.breakdown) + assert json_str is not None + assert len(json_str) > 0 + + # Should be valid JSON + parsed = json.loads(json_str) + assert parsed == result.breakdown diff --git a/service/tests/unit/test_schemas/test_model_tier.py b/service/tests/unit/test_schemas/test_model_tier.py new file mode 100644 index 00000000..aab2b930 --- /dev/null +++ b/service/tests/unit/test_schemas/test_model_tier.py @@ -0,0 +1,99 @@ +"""Unit tests for model tier resolution.""" + +from app.schemas.model_tier import ( + TIER_MODEL_CANDIDATES, + ModelTier, + get_fallback_model_for_tier, + resolve_model_for_tier, +) + + +class TestModelTier: + """Test ModelTier enum and resolution.""" + + def test_model_tier_values(self) -> None: + """Test that ModelTier has expected values.""" + assert ModelTier.ULTRA.value == "ultra" + assert ModelTier.PRO.value == "pro" + assert ModelTier.STANDARD.value == "standard" + assert ModelTier.LITE.value == "lite" + + def test_resolve_model_for_tier_ultra(self) -> None: + """Test ULTRA tier returns expected fallback model.""" + model = resolve_model_for_tier(ModelTier.ULTRA) + fallback = get_fallback_model_for_tier(ModelTier.ULTRA) + assert model == fallback.model + assert model is not None + assert len(model) > 0 + + def test_resolve_model_for_tier_pro(self) -> None: + """Test PRO tier returns expected fallback model.""" + model = resolve_model_for_tier(ModelTier.PRO) + fallback = get_fallback_model_for_tier(ModelTier.PRO) + assert model == fallback.model + assert model is not None + assert len(model) > 0 + + def test_resolve_model_for_tier_standard(self) -> None: + """Test STANDARD tier returns expected fallback model.""" + model = resolve_model_for_tier(ModelTier.STANDARD) + fallback = get_fallback_model_for_tier(ModelTier.STANDARD) + assert model == fallback.model + assert model is not None + assert len(model) > 0 + + def test_resolve_model_for_tier_lite(self) -> None: + """Test LITE tier returns expected fallback model.""" + model = resolve_model_for_tier(ModelTier.LITE) + fallback = get_fallback_model_for_tier(ModelTier.LITE) + assert model == fallback.model + assert model is not None + assert len(model) > 0 + + def test_all_tiers_have_mapping(self) -> None: + """Test that all tiers have a model mapping.""" + for tier in ModelTier: + 
assert tier in TIER_MODEL_CANDIDATES + model = resolve_model_for_tier(tier) + assert model is not None + assert len(model) > 0 + + def test_all_tiers_have_candidates(self) -> None: + """Test that all tiers have at least one candidate.""" + for tier in ModelTier: + candidates = TIER_MODEL_CANDIDATES[tier] + assert len(candidates) > 0 + # Each tier should have at least one candidate + assert all(c.model for c in candidates) + assert all(c.provider_type for c in candidates) + + def test_all_tiers_have_fallback(self) -> None: + """Test that all tiers have a fallback model.""" + for tier in ModelTier: + fallback = get_fallback_model_for_tier(tier) + assert fallback is not None + assert fallback.model is not None + assert len(fallback.model) > 0 + assert fallback.provider_type is not None + + def test_fallback_models_are_marked(self) -> None: + """Test that fallback models have is_fallback=True.""" + for tier in ModelTier: + candidates = TIER_MODEL_CANDIDATES[tier] + fallback_candidates = [c for c in candidates if c.is_fallback] + # Each tier should have at least one fallback + assert len(fallback_candidates) >= 1 + # Fallback should be retrievable + fallback = get_fallback_model_for_tier(tier) + assert fallback.is_fallback is True + + def test_candidate_priorities(self) -> None: + """Test that candidates have valid priorities.""" + for tier in ModelTier: + candidates = TIER_MODEL_CANDIDATES[tier] + for candidate in candidates: + # Priority should be non-negative + assert candidate.priority >= 0 + # Fallback should have high priority number (low priority) + if candidate.is_fallback: + assert candidate.priority >= 90 diff --git a/web/src/components/features/SettingsButton.tsx b/web/src/components/features/SettingsButton.tsx index 367314e1..67f3a11b 100644 --- a/web/src/components/features/SettingsButton.tsx +++ b/web/src/components/features/SettingsButton.tsx @@ -19,7 +19,7 @@ export const SettingsButton = ({ )} */} - {/* Model Selector */} + {/* Tier Selector */} {activeChatChannel && currentAgent && ( - )} diff --git a/web/src/components/layouts/components/TierSelector.tsx b/web/src/components/layouts/components/TierSelector.tsx new file mode 100644 index 00000000..8fd0e73e --- /dev/null +++ b/web/src/components/layouts/components/TierSelector.tsx @@ -0,0 +1,137 @@ +"use client"; + +import { ChevronDownIcon, CpuChipIcon } from "@heroicons/react/24/outline"; +import { AnimatePresence, motion } from "motion/react"; +import { useState } from "react"; +import { useTranslation } from "react-i18next"; + +export type ModelTier = "ultra" | "pro" | "standard" | "lite"; + +interface TierSelectorProps { + currentTier: ModelTier | null | undefined; + onTierChange: (tier: ModelTier) => void; + disabled?: boolean; +} + +interface TierConfig { + key: ModelTier; + bgColor: string; + textColor: string; + dotColor: string; +} + +const TIER_CONFIGS: TierConfig[] = [ + { + key: "ultra", + bgColor: "bg-purple-500/10 dark:bg-purple-500/20", + textColor: "text-purple-700 dark:text-purple-400", + dotColor: "bg-purple-500", + }, + { + key: "pro", + bgColor: "bg-blue-500/10 dark:bg-blue-500/20", + textColor: "text-blue-700 dark:text-blue-400", + dotColor: "bg-blue-500", + }, + { + key: "standard", + bgColor: "bg-green-500/10 dark:bg-green-500/20", + textColor: "text-green-700 dark:text-green-400", + dotColor: "bg-green-500", + }, + { + key: "lite", + bgColor: "bg-orange-500/10 dark:bg-orange-500/20", + textColor: "text-orange-700 dark:text-orange-400", + dotColor: "bg-orange-500", + }, +]; + +export function 
TierSelector({ + currentTier, + onTierChange, + disabled = false, +}: TierSelectorProps) { + const { t } = useTranslation(); + const [isOpen, setIsOpen] = useState(false); + + // Default to standard if no tier is selected + const effectiveTier = currentTier || "standard"; + const currentConfig = + TIER_CONFIGS.find((c) => c.key === effectiveTier) || TIER_CONFIGS[2]; + + const handleTierClick = (tier: ModelTier) => { + onTierChange(tier); + setIsOpen(false); + }; + + return ( +
!disabled && setIsOpen(true)} + onMouseLeave={() => setIsOpen(false)} + > + {/* Main Trigger Button */} + !disabled && setIsOpen(!isOpen)} + > + + + {t(`app.tierSelector.tiers.${effectiveTier}.name`)} + + + + + {/* Dropdown */} + + {isOpen && ( + +
+ {t("app.tierSelector.title")} +
+
+ {TIER_CONFIGS.map((config, index) => ( + handleTierClick(config.key)} + className={`flex w-full items-center gap-2 rounded-md px-3 py-2 text-left transition-colors ${ + effectiveTier === config.key + ? `${config.bgColor} ${config.textColor}` + : "hover:bg-neutral-100 dark:hover:bg-neutral-800" + }`} + > +
+
+
+ {t(`app.tierSelector.tiers.${config.key}.name`)} +
+
+ {t(`app.tierSelector.tiers.${config.key}.description`)} +
+
+ + ))} +
+ + )} + +
+ ); +} diff --git a/web/src/components/modals/AddLlmProviderModal.tsx b/web/src/components/modals/AddLlmProviderModal.tsx deleted file mode 100644 index 111d0887..00000000 --- a/web/src/components/modals/AddLlmProviderModal.tsx +++ /dev/null @@ -1,253 +0,0 @@ -import { Modal } from "@/components/animate-ui/components/animate/modal"; -import { Input } from "@/components/base/Input"; -import { useXyzen } from "@/store"; -import type { LlmProviderCreate } from "@/types/llmProvider"; -import { Button, Field, Label } from "@headlessui/react"; -import { useState, type ChangeEvent } from "react"; - -export function AddLlmProviderModal() { - const { isAddLlmProviderModalOpen, closeAddLlmProviderModal, addProvider } = - useXyzen(); - const [newProvider, setNewProvider] = useState< - Omit - >({ - name: "", - api: "", - key: "", - model: "", - max_tokens: 4096, - temperature: 0.7, - timeout: 60, - }); - const [error, setError] = useState(null); - const [loading, setLoading] = useState(false); - - const handleInputChange = (e: ChangeEvent) => { - const { name, value } = e.target; - setNewProvider((prev) => ({ - ...prev, - [name]: - name === "max_tokens" || name === "temperature" || name === "timeout" - ? value === "" - ? Number(0) - : Number(value) - : value, - })); - }; - - const handleReset = () => { - setNewProvider({ - name: "", - api: "", - key: "", - model: "", - max_tokens: 4096, - temperature: 0.7, - timeout: 60, - }); - setError(null); - }; - - const handleAddProvider = async () => { - setError(null); - setLoading(true); - - if ( - !newProvider.name || - !newProvider.api || - !newProvider.key || - !newProvider.model - ) { - setError("Name, API endpoint, API key, and model are required."); - setLoading(false); - return; - } - - try { - await addProvider({ - ...newProvider, - provider_type: "openai", - user_id: "", - } as LlmProviderCreate); - handleReset(); - closeAddLlmProviderModal(); - } catch (error) { - setError( - error instanceof Error ? error.message : "Failed to add provider", - ); - } finally { - setLoading(false); - } - }; - - const handleClose = () => { - handleReset(); - closeAddLlmProviderModal(); - }; - - const getProviderTemplates = (name: string) => { - const lowerName = name.toLowerCase(); - if (lowerName.includes("openai") && !lowerName.includes("azure")) { - return { - api: "https://api.openai.com/v1", - model: "gpt-4o-mini", - }; - } else if (lowerName.includes("azure")) { - return { - api: "https://your-resource.openai.azure.com/", - model: "gpt-4o", - }; - } else if ( - lowerName.includes("anthropic") || - lowerName.includes("claude") - ) { - return { - api: "https://api.anthropic.com", - model: "claude-3-haiku-20240307", - }; - } - return {}; - }; - - const handleNameChange = (e: ChangeEvent) => { - const { value } = e.target; - const templates = getProviderTemplates(value); - setNewProvider((prev) => ({ - ...prev, - name: value, - ...templates, - })); - }; - - return ( - -
- - - - - - - - - - - - - - - - - - - - -
- - - - - - - - - - - - - - -
- - {error && ( -
-

{error}

-
- )} - -
- - -
-
-
- ); -} diff --git a/web/src/components/modals/SettingsModal.tsx b/web/src/components/modals/SettingsModal.tsx index e6171092..4390fbab 100644 --- a/web/src/components/modals/SettingsModal.tsx +++ b/web/src/components/modals/SettingsModal.tsx @@ -3,7 +3,6 @@ import { useXyzen } from "@/store"; import { AdjustmentsHorizontalIcon, ArrowLeftIcon, - CloudIcon, GiftIcon, ServerStackIcon, } from "@heroicons/react/24/outline"; @@ -13,8 +12,6 @@ import { useTranslation } from "react-i18next"; import { LanguageSettings, - ProviderConfigForm, - ProviderList, RedemptionSettings, StyleSettings, ThemeSettings, @@ -30,8 +27,6 @@ export function SettingsModal() { activeSettingsCategory, setActiveSettingsCategory, activeUiSetting, - selectedProviderId, - setSelectedProvider, } = useXyzen(); // Mobile navigation state: 'categories' | 'content' @@ -46,11 +41,6 @@ export function SettingsModal() { label: t("settings.categories.ui"), icon: AdjustmentsHorizontalIcon, }, - { - id: "provider", - label: t("settings.categories.provider"), - icon: CloudIcon, - }, { id: "mcp", label: t("settings.categories.mcp"), @@ -144,47 +134,6 @@ export function SettingsModal() {
{activeSettingsCategory === "mcp" && } - {activeSettingsCategory === "provider" && ( -
- {/* Provider List Column */} -
- -
- {/* Provider Config Column */} -
- {/* Mobile Back Button */} -
- - - {t("settings.categories.provider")} - -
- - {selectedProviderId ? ( - - ) : ( -
- -

{t("settings.provider.emptyHint")}

-
- )} -
-
- )} - {activeSettingsCategory === "ui" && (
).azure_version; - return azureVersion === undefined || typeof azureVersion === "string"; -} - -function getAzureVersion(config: AzureProviderConfig): string { - return typeof config.azure_version === "string" - ? config.azure_version - : "2024-02-15-preview"; -} - -// Helper function to get default API endpoint for provider type -const getDefaultApiEndpoint = (providerType: string): string => { - const defaultEndpoints: Record = { - openai: "https://api.openai.com/v1", - azure_openai: "https://YOUR_RESOURCE.openai.azure.com", - google: "https://generativelanguage.googleapis.com", - google_vertex: "", - anthropic: "https://api.anthropic.com", - }; - return defaultEndpoints[providerType] || ""; -}; - -export const ProviderConfigForm = () => { - const { t } = useTranslation(); - const { selectedProviderId, setUserDefaultProvider, userDefaultProviderId } = - useXyzen(); - - // Use TanStack Query hooks for provider data - const { data: llmProviders = [] } = useMyProviders(); - const { data: providerTemplates = [] } = useProviderTemplates(); - const createProviderMutation = useCreateProvider(); - const updateProviderMutation = useUpdateProvider(); - const deleteProviderMutation = useDeleteProvider(); - - const [formData, setFormData] = useState>({ - name: "", - provider_type: "", - api: "", - key: "", - user_id: "", // This will be set by backend from auth token - provider_config: {}, - }); - - // Azure-specific config state - const [azureConfig, setAzureConfig] = useState({ - azure_version: "2024-02-15-preview", - }); - - const [loading, setLoading] = useState(false); - const [error, setError] = useState(null); - const [success, setSuccess] = useState(null); - const [isEditing, setIsEditing] = useState(false); - - // Load data when selected provider changes - useEffect(() => { - setError(null); - setSuccess(null); - - if (!selectedProviderId) { - // No selection - setFormData({ - name: "", - provider_type: "", - api: "", - key: "", - user_id: "", - provider_config: {}, - }); - setAzureConfig({ - azure_version: "2024-02-15-preview", - }); - setIsEditing(false); - return; - } - - if (selectedProviderId.startsWith("new:")) { - // Creating new provider from template - const templateType = selectedProviderId.replace("new:", ""); - const template = providerTemplates.find((t) => t.type === templateType); - if (template) { - setFormData({ - name: t("settings.providers.form.defaults.name", { - name: template.display_name, - }), - provider_type: template.type, - api: getDefaultApiEndpoint(template.type), - key: "", - user_id: "", - provider_config: {}, - }); - - // Initialize Azure config if Azure OpenAI - if (template.type === "azure_openai") { - setAzureConfig({ - azure_version: "2024-02-15-preview", - }); - } else { - setAzureConfig({ - azure_version: "2024-02-15-preview", - }); - } - - setIsEditing(false); - } - } else { - // Editing existing provider - const provider = llmProviders.find((p) => p.id === selectedProviderId); - if (provider) { - // Check if it's a system provider - if (provider.is_system) { - setError(t("settings.providers.form.errors.systemReadOnlyEdit")); - setFormData({ - name: provider.name, - provider_type: provider.provider_type, - api: provider.api, - key: "••••••••", // Mask the key - user_id: provider.user_id, - provider_config: provider.provider_config || {}, - }); - - // Load Azure config if exists - if ( - provider.provider_type === "azure_openai" && - provider.provider_config && - isAzureProviderConfig(provider.provider_config) - ) { - setAzureConfig({ - azure_version: 
getAzureVersion(provider.provider_config), - }); - } - - setIsEditing(false); // Prevent editing - return; - } - - setFormData({ - name: provider.name, - provider_type: provider.provider_type, - api: provider.api, - key: provider.key, - user_id: provider.user_id, - provider_config: provider.provider_config || {}, - }); - - // Load Azure config if exists - if ( - provider.provider_type === "azure_openai" && - provider.provider_config && - isAzureProviderConfig(provider.provider_config) - ) { - setAzureConfig({ - azure_version: getAzureVersion(provider.provider_config), - }); - } else { - setAzureConfig({ - azure_version: "2024-02-15-preview", - }); - } - - setIsEditing(true); - } - } - }, [selectedProviderId, providerTemplates, llmProviders, t]); - - const handleInputChange = (e: ChangeEvent) => { - const { name, value } = e.target; - setFormData((prev) => ({ - ...prev, - [name]: value, - })); - }; - - const handleAzureConfigChange = (e: ChangeEvent) => { - const { name, value } = e.target; - setAzureConfig((prev) => ({ - ...prev, - [name]: value, - })); - }; - - const handleSave = async () => { - setError(null); - setSuccess(null); - setLoading(true); - - try { - // Validation - if (!formData.name || !formData.api || !formData.key) { - setError(t("settings.providers.form.errors.required")); - setLoading(false); - return; - } - - if ( - isEditing && - selectedProviderId && - !selectedProviderId.startsWith("new:") - ) { - // Update existing provider - const updateData: LlmProviderUpdate = { - name: formData.name, - api: formData.api, - key: formData.key, - }; - - // Add Azure config if Azure OpenAI - if (formData.provider_type === "azure_openai") { - updateData.provider_config = { - azure_version: azureConfig.azure_version, - }; - } - - await updateProviderMutation.mutateAsync({ - id: selectedProviderId, - provider: updateData, - }); - setSuccess(t("settings.providers.form.success.updated")); - } else { - // Create new provider - const createData: LlmProviderCreate = { - scope: ProviderScope.USER, - name: formData.name!, - provider_type: formData.provider_type!, - api: formData.api!, - key: formData.key!, - user_id: "", // Backend will set this - }; - - // Add Azure config if Azure OpenAI - if (formData.provider_type === "azure_openai") { - createData.provider_config = { - azure_version: azureConfig.azure_version, - }; - } - - await createProviderMutation.mutateAsync(createData); - setSuccess(t("settings.providers.form.success.created")); - } - } catch (err) { - setError( - err instanceof Error - ? err.message - : t("settings.providers.form.errors.saveFailed"), - ); - } finally { - setLoading(false); - } - }; - - const handleDelete = async () => { - if (!selectedProviderId || selectedProviderId.startsWith("new:")) return; - - // Check if it's a system provider - const provider = llmProviders.find((p) => p.id === selectedProviderId); - if (provider?.is_system) { - setError(t("settings.providers.form.errors.systemReadOnlyDelete")); - return; - } - - if (!confirm(t("settings.providers.form.confirm.delete"))) return; - - setError(null); - setSuccess(null); - setLoading(true); - - try { - await deleteProviderMutation.mutateAsync(selectedProviderId); - setSuccess(t("settings.providers.form.success.deleted")); - } catch (err) { - setError( - err instanceof Error - ? 
err.message - : t("settings.providers.form.errors.deleteFailed"), - ); - } finally { - setLoading(false); - } - }; - - const handleSetDefault = async () => { - if (!selectedProviderId || selectedProviderId.startsWith("new:")) return; - - // Check if it's a system provider - const provider = llmProviders.find((p) => p.id === selectedProviderId); - if (provider?.is_system) { - setError(t("settings.providers.form.errors.systemReadOnlyDefault")); - return; - } - - setError(null); - setSuccess(null); - setLoading(true); - - try { - setUserDefaultProvider(selectedProviderId); - setSuccess(t("settings.providers.form.success.defaultSet")); - } catch (err) { - setError( - err instanceof Error - ? err.message - : t("settings.providers.form.errors.defaultFailed"), - ); - } finally { - setLoading(false); - } - }; - - if (!selectedProviderId) { - return ( -
-
-

{t("settings.providers.form.empty.title")}

-

- {t("settings.providers.form.empty.subtitle")} -

-
-
- ); - } - - const template = selectedProviderId.startsWith("new:") - ? providerTemplates.find( - (t) => t.type === selectedProviderId.replace("new:", ""), - ) - : providerTemplates.find((t) => t.type === formData.provider_type); - - // Check if current provider is a system provider (read-only) - const currentProvider = llmProviders.find((p) => p.id === selectedProviderId); - const isSystemProvider = currentProvider?.is_system || false; - - return ( -
-
-

- {isEditing - ? t("settings.providers.form.title.edit") - : t("settings.providers.form.title.create")} - {isSystemProvider && ( - - {t("settings.providers.form.systemBadge")} - - )} -

- - {error && ( -
- {error} -
- )} - - {success && ( -
- {success} -
- )} - -
- {/* Provider Type (readonly for existing) */} - {template && ( -
-
- {t("settings.providers.form.providerType.title")} -
-
- {template.display_name} -
-
- {t("settings.providers.form.providerType.availableModels", { - count: template.models.length, - })} -
-
- )} - - {/* Name */} - - - - - - {/* API Endpoint */} - - - -

- {t("settings.providers.form.fields.api.help")} -

-
- - {/* API Key */} - - - - - - {/* Azure OpenAI Specific Fields */} - {formData.provider_type === "azure_openai" && ( - - - -

- {t("settings.providers.form.fields.azureVersion.help")} -

-
- )} - - {/* Set as Default (for existing providers) */} - {isEditing && !isSystemProvider && ( - - - - - - - )} -
-
- - {/* Action Buttons */} -
-
-
- {isEditing && !isSystemProvider && ( - - )} -
-
- {!isSystemProvider && ( - - )} -
-
-
-
- ); -}; diff --git a/web/src/components/modals/settings/ProviderList.tsx b/web/src/components/modals/settings/ProviderList.tsx deleted file mode 100644 index 61ef4f08..00000000 --- a/web/src/components/modals/settings/ProviderList.tsx +++ /dev/null @@ -1,213 +0,0 @@ -import { - AnthropicIcon, - AzureIcon, - GoogleIcon, - OpenAIIcon, -} from "@/assets/icons"; -import { getProviderDisplayName } from "@/utils/providerDisplayNames"; -import { - Tabs, - TabsHighlight, - TabsHighlightItem, - TabsList, - TabsTrigger, -} from "@/components/animate-ui/primitives/radix/tabs"; -import { LoadingSpinner } from "@/components/base/LoadingSpinner"; -import { useXyzen } from "@/store"; -import { - useMyProviders, - useProviderTemplates, - useDeleteProvider, -} from "@/hooks/queries"; -import { - CheckCircleIcon, - PlusCircleIcon, - TrashIcon, -} from "@heroicons/react/24/outline"; -import { useTranslation } from "react-i18next"; - -export const ProviderList = () => { - const { t } = useTranslation(); - const { setSelectedProvider, selectedProviderId } = useXyzen(); - - // Use TanStack Query hooks for provider data - const { data: llmProviders = [], isLoading: llmProvidersLoading } = - useMyProviders(); - const { data: providerTemplates = [], isLoading: templatesLoading } = - useProviderTemplates(); - const deleteProviderMutation = useDeleteProvider(); - - const getProviderIcon = (type: string) => { - const iconClass = "h-5 w-5"; - switch (type) { - case "google": - return ; - case "openai": - return ; - case "google_vertex": - return ; - case "azure_openai": - return ; - case "anthropic": - return ; - case "gpugeek": - return ( -
X
- ); - case "qwen": - return ( -
Q
- ); - default: - return ; - } - }; - - const onValueChange = (value: string) => { - setSelectedProvider(value); - }; - - if (templatesLoading || llmProvidersLoading) { - return ( -
- -
- ); - } - - // Hide system providers from end users. - // Some backends may omit `is_system`; fall back to provider_type conventions. - const myProviders = llmProviders.filter((p) => { - const providerType = p.provider_type; - const providerName = p.name.toLowerCase(); - const isSystemType = - providerName === "system" || providerType?.startsWith("system_"); - return !(p.is_system || isSystemType); - }); - // Filter out system templates if any (though usually templates are for creation) - const availableTemplates = providerTemplates.filter((t) => { - if (t.type === "system") return false; - if (t.type.startsWith("system_")) return false; - return true; - }); - - const triggerBaseClassName = - "group relative z-10 flex w-full items-center justify-between rounded-lg px-4 py-3 text-left text-sm font-medium transition-all hover:bg-neutral-200/50 dark:hover:bg-neutral-800/50 data-[state=active]:bg-white data-[state=active]:text-indigo-600 data-[state=active]:shadow-sm data-[state=active]:ring-1 data-[state=active]:ring-neutral-200 dark:data-[state=active]:bg-neutral-800 dark:data-[state=active]:text-indigo-400 dark:data-[state=active]:ring-neutral-700"; - - return ( - -
- - {/* My Providers Section */} - {myProviders.length > 0 && ( -
-

- {t("settings.providers.my.title")} -

- - {myProviders.map((provider) => ( - -
- -
-
- {getProviderIcon(provider.provider_type)} -
-
-
- {/*{provider.is_system - ? getProviderDisplayName(provider.provider_type) - : provider.name}*/} - {getProviderDisplayName(provider.provider_type)} -
-
-
-
- {provider.is_default && ( - - )} -
-
- - -
-
- ))} -
-
- )} - - {/* Provider Templates Section */} -
-

- {t("settings.providers.templates.title")} -

- - {availableTemplates.map((template) => ( - - -
- {getProviderIcon(template.type)} -
-
-
- - {getProviderDisplayName(template.type)} - - -
-
- {t("settings.providers.templates.supports", { - count: template.models.length, - })} -
-
-
-
- ))} -
-
-
-
-
- ); -}; diff --git a/web/src/components/modals/settings/index.ts b/web/src/components/modals/settings/index.ts index 74f93292..a7ea625f 100644 --- a/web/src/components/modals/settings/index.ts +++ b/web/src/components/modals/settings/index.ts @@ -1,6 +1,4 @@ export { LanguageSettings } from "./LanguageSettings"; -export { ProviderConfigForm } from "./ProviderConfigForm"; -export { ProviderList } from "./ProviderList"; export { RedemptionSettings } from "./RedemptionSettings"; export { StyleSettings } from "./StyleSettings"; export { ThemeSettings } from "./ThemeSettings"; diff --git a/web/src/core/session/types.ts b/web/src/core/session/types.ts index 74ec785d..9deb50f9 100644 --- a/web/src/core/session/types.ts +++ b/web/src/core/session/types.ts @@ -14,6 +14,7 @@ export interface SessionResponse { user_id: string; provider_id?: string; model?: string; + model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; created_at: string; updated_at: string; @@ -39,6 +40,7 @@ export interface SessionCreatePayload { agent_id?: string; provider_id?: string; model?: string; + model_tier?: "ultra" | "pro" | "standard" | "lite"; mcp_server_ids?: string[]; } diff --git a/web/src/hooks/queries/index.ts b/web/src/hooks/queries/index.ts index 49a5601d..328b4b84 100644 --- a/web/src/hooks/queries/index.ts +++ b/web/src/hooks/queries/index.ts @@ -11,6 +11,7 @@ export { queryKeys } from "./queryKeys"; // Provider queries export { useMyProviders, + useSystemProviders, useProviderTemplates, useAvailableModels, useDefaultModelConfig, diff --git a/web/src/hooks/queries/queryKeys.ts b/web/src/hooks/queries/queryKeys.ts index 607c3ffb..b18348ff 100644 --- a/web/src/hooks/queries/queryKeys.ts +++ b/web/src/hooks/queries/queryKeys.ts @@ -34,6 +34,7 @@ export const queryKeys = { providers: { all: ["providers"] as const, my: () => [...queryKeys.providers.all, "my"] as const, + system: () => [...queryKeys.providers.all, "system"] as const, templates: () => [...queryKeys.providers.all, "templates"] as const, models: () => [...queryKeys.providers.all, "models"] as const, defaultConfig: () => [...queryKeys.providers.all, "defaultConfig"] as const, diff --git a/web/src/hooks/queries/useProvidersQuery.ts b/web/src/hooks/queries/useProvidersQuery.ts index e6b6aa76..e75675c7 100644 --- a/web/src/hooks/queries/useProvidersQuery.ts +++ b/web/src/hooks/queries/useProvidersQuery.ts @@ -27,6 +27,22 @@ export function useMyProviders() { }); } +/** + * Fetch system providers only (no user-defined providers) + * + * @example + * ```tsx + * const { data: providers, isLoading, error } = useSystemProviders(); + * ``` + */ +export function useSystemProviders() { + return useQuery({ + queryKey: queryKeys.providers.system(), + queryFn: () => llmProviderService.getSystemProviders(), + staleTime: 5 * 60 * 1000, // Consider fresh for 5 minutes + }); +} + /** * Fetch provider templates for creating new providers */ diff --git a/web/src/i18n/locales/en/app.json b/web/src/i18n/locales/en/app.json index c40efbb0..2bbae6e5 100644 --- a/web/src/i18n/locales/en/app.json +++ b/web/src/i18n/locales/en/app.json @@ -58,5 +58,26 @@ "modelSelector": { "noProvider": "Please add an LLM provider first", "notConfigured": "Not Configured" + }, + "tierSelector": { + "title": "Select Model Tier", + "tiers": { + "ultra": { + "name": "Xyzen Ultra", + "description": "Most powerful models for complex tasks" + }, + "pro": { + "name": "Xyzen Pro", + "description": "Advanced models for professional use" + }, + "standard": { + "name": "Xyzen 
Standard", + "description": "Balanced performance and efficiency" + }, + "lite": { + "name": "Xyzen Lite", + "description": "Fast and lightweight for simple tasks" + } + } } } diff --git a/web/src/i18n/locales/ja/app.json b/web/src/i18n/locales/ja/app.json index 94fea3fb..9c8dd9f6 100644 --- a/web/src/i18n/locales/ja/app.json +++ b/web/src/i18n/locales/ja/app.json @@ -58,5 +58,26 @@ "modelSelector": { "noProvider": "まずLLMプロバイダーを追加してください", "notConfigured": "未設定" + }, + "tierSelector": { + "title": "モデルティアを選択", + "tiers": { + "ultra": { + "name": "Xyzen ウルトラ", + "description": "複雑なタスク向けの最強モデル" + }, + "pro": { + "name": "Xyzen プロ", + "description": "プロフェッショナル向けの高度なモデル" + }, + "standard": { + "name": "Xyzen スタンダード", + "description": "性能と効率のバランス" + }, + "lite": { + "name": "Xyzen ライト", + "description": "シンプルなタスク向けの高速軽量モデル" + } + } } } diff --git a/web/src/i18n/locales/zh/app.json b/web/src/i18n/locales/zh/app.json index d87d87c6..5cb321e0 100644 --- a/web/src/i18n/locales/zh/app.json +++ b/web/src/i18n/locales/zh/app.json @@ -58,5 +58,26 @@ "modelSelector": { "noProvider": "请先添加LLM提供商", "notConfigured": "未设置" + }, + "tierSelector": { + "title": "选择模型等级", + "tiers": { + "ultra": { + "name": "Xyzen Ultra", + "description": "最强大的模型,适合复杂任务" + }, + "pro": { + "name": "Xyzen Pro", + "description": "高级模型,适合专业用途" + }, + "standard": { + "name": "Xyzen Standard", + "description": "性能与效率均衡" + }, + "lite": { + "name": "Xyzen Lite", + "description": "快速轻量,适合简单任务" + } + } } } diff --git a/web/src/service/llmProviderService.ts b/web/src/service/llmProviderService.ts index 926bb118..d96157fa 100644 --- a/web/src/service/llmProviderService.ts +++ b/web/src/service/llmProviderService.ts @@ -124,6 +124,22 @@ class LlmProviderService { return response.json(); } + /** + * Get system providers only (no user-defined providers) + */ + async getSystemProviders(): Promise { + const response = await fetch( + `${this.getBackendUrl()}/xyzen/api/v1/providers/system`, + { + headers: this.createAuthHeaders(), + }, + ); + if (!response.ok) { + throw new Error("Failed to fetch system providers"); + } + return response.json(); + } + /** * Create a new provider */ diff --git a/web/src/service/sessionService.ts b/web/src/service/sessionService.ts index b128a8e8..2d8feade 100644 --- a/web/src/service/sessionService.ts +++ b/web/src/service/sessionService.ts @@ -8,6 +8,7 @@ export interface SessionCreate { agent_id?: string; provider_id?: string; model?: string; + model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; } @@ -17,6 +18,7 @@ export interface SessionUpdate { is_active?: boolean; provider_id?: string; model?: string; + model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; } @@ -29,6 +31,7 @@ export interface SessionRead { user_id: string; provider_id?: string; model?: string; + model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; created_at: string; updated_at: string; diff --git a/web/src/store/slices/chatSlice.ts b/web/src/store/slices/chatSlice.ts index 20fb3f11..dcde7c55 100644 --- a/web/src/store/slices/chatSlice.ts +++ b/web/src/store/slices/chatSlice.ts @@ -65,6 +65,7 @@ export interface ChatSlice { config: { provider_id?: string; model?: string; + model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; }, ) => Promise; @@ -1971,6 +1972,8 @@ export const createChatSlice: StateCreator< state.channels[activeChannelId].provider_id = updatedSession.provider_id; state.channels[activeChannelId].model 
= updatedSession.model; + state.channels[activeChannelId].model_tier = + updatedSession.model_tier; state.channels[activeChannelId].google_search_enabled = updatedSession.google_search_enabled; } diff --git a/web/src/store/types.ts b/web/src/store/types.ts index 50c195ed..4bedff71 100644 --- a/web/src/store/types.ts +++ b/web/src/store/types.ts @@ -157,6 +157,7 @@ export interface ChatChannel { agentId?: string; provider_id?: string; model?: string; + model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; knowledgeContext?: KnowledgeContext; connected: boolean; @@ -200,6 +201,7 @@ export interface SessionResponse { agent_id?: string; provider_id?: string; model?: string; + model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; topics: TopicResponse[]; } From 6c8e1fe32bedf631781974a737b3f7f31f452f77 Mon Sep 17 00:00:00 2001 From: Harvey Date: Thu, 15 Jan 2026 01:33:46 +0800 Subject: [PATCH 06/11] feat: enhance agent node with stats display and layout persistence - Added StatsDisplay component to visualize agent statistics (message count, topic count, token usage) based on grid size. - Updated AgentNode to integrate StatsDisplay and improve layout handling. - Introduced AgentSpatialLayout type for managing agent layout in sessions. - Enhanced agent slice to fetch and store agent stats and spatial layout from the backend. - Implemented session statistics aggregation for agents, including message and token counts. - Added new migrations for session spatial layout and status. - Created AddAgentButton and SaveStatusIndicator components for improved UI interactions. --- service/app/api/v1/agents.py | 201 +++++++++------ service/app/core/session/service.py | 1 + service/app/models/__init__.py | 4 + service/app/models/session_stats.py | 48 ++++ service/app/models/sessions.py | 11 +- service/app/repos/session.py | 1 + service/app/repos/session_stats.py | 217 ++++++++++++++++ ...16fcd09524ba_add_session_spatial_layout.py | 33 +++ .../e427ec7ce799_add_session_status.py | 30 +++ web/src/app/chat/SpatialWorkspace.tsx | 238 ++++++++++++------ web/src/app/chat/spatial/AddAgentButton.tsx | 39 +++ web/src/app/chat/spatial/AgentNode.tsx | 203 ++++++++++++--- .../app/chat/spatial/SaveStatusIndicator.tsx | 75 ++++++ web/src/app/chat/spatial/types.ts | 38 ++- web/src/core/session/types.ts | 1 + web/src/service/sessionService.ts | 4 + web/src/store/slices/agentSlice.ts | 195 +++++++++++++- web/src/store/types.ts | 3 +- web/src/types/agents.ts | 55 ++++ 19 files changed, 1210 insertions(+), 187 deletions(-) create mode 100644 service/app/models/session_stats.py create mode 100644 service/app/repos/session_stats.py create mode 100644 service/migrations/versions/16fcd09524ba_add_session_spatial_layout.py create mode 100644 service/migrations/versions/e427ec7ce799_add_session_status.py create mode 100644 web/src/app/chat/spatial/AddAgentButton.tsx create mode 100644 web/src/app/chat/spatial/SaveStatusIndicator.tsx diff --git a/service/app/api/v1/agents.py b/service/app/api/v1/agents.py index 83a2f960..1e4a3013 100644 --- a/service/app/api/v1/agents.py +++ b/service/app/api/v1/agents.py @@ -9,6 +9,8 @@ - DELETE /{agent_id}: Delete an agent. - GET /system/chat: Get the user's default chat agent. - GET /system/all: Get all user default agents. +- GET /stats: Get aggregated stats for all agents (from sessions/messages). +- GET /{agent_id}/stats: Get aggregated stats for a specific agent. 
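Editorial note: a minimal sketch of how the two stats routes listed above might be exercised from a client. The /xyzen/api/v1/agents prefix, the bearer-token header, and every number in the comments are assumptions for illustration only; the route paths and response models are the ones declared in this patch.

import httpx

BASE = "http://localhost:8000/xyzen/api/v1/agents"  # assumed mount prefix
HEADERS = {"Authorization": "Bearer <token>"}        # assumed auth scheme

# Aggregated stats for every agent the current user has used.
# The response is a dict keyed by agent_id (stringified UUID).
all_stats = httpx.get(f"{BASE}/stats", headers=HEADERS).json()
# e.g. {"2f6c...": {"agent_id": "2f6c...", "session_count": 3,
#                    "topic_count": 7, "message_count": 42,
#                    "input_tokens": 120000, "output_tokens": 18500}}

# Stats for a single agent owned by the user (hypothetical ID).
agent_id = "2f6c0000-0000-0000-0000-000000000000"
one_agent = httpx.get(f"{BASE}/{agent_id}/stats", headers=HEADERS).json()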
""" from uuid import UUID @@ -16,16 +18,18 @@ from fastapi import APIRouter, Depends, HTTPException from sqlmodel.ext.asyncio.session import AsyncSession +from app.agents.types import SystemAgentInfo from app.common.code import ErrCodeError, handle_auth_error from app.core.auth import AuthorizationService, get_auth_service from app.core.system_agent import SystemAgentManager from app.infra.database import get_session from app.middleware.auth import get_current_user -from app.agents.types import SystemAgentInfo from app.models.agent import AgentCreate, AgentRead, AgentReadWithDetails, AgentScope, AgentUpdate +from app.models.session_stats import AgentStatsAggregated from app.repos import AgentRepository, KnowledgeSetRepository, ProviderRepository from app.repos.agent_marketplace import AgentMarketplaceRepository from app.repos.session import SessionRepository +from app.repos.session_stats import SessionStatsRepository router = APIRouter(tags=["agents"]) @@ -206,6 +210,116 @@ async def create_agent_from_template( return AgentRead(**created_agent.model_dump()) +@router.get("/stats", response_model=dict[str, AgentStatsAggregated]) +async def get_all_agent_stats( + user: str = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, AgentStatsAggregated]: + """ + Get aggregated stats for all agents the user has interacted with. + + Stats are computed by aggregating data from sessions, topics, and messages. + Returns a dictionary mapping agent_id to aggregated stats. + + Args: + user: Authenticated user ID (injected by dependency) + db: Database session (injected by dependency) + + Returns: + dict[str, AgentStatsAggregated]: Dictionary of agent_id -> aggregated stats + """ + stats_repo = SessionStatsRepository(db) + return await stats_repo.get_all_agent_stats_for_user(user) + + +@router.get("/system/chat", response_model=AgentReadWithDetails) +async def get_system_chat_agent( + user: str = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> AgentReadWithDetails: + """ + Get the user's default chat agent. + + Returns the user's personal copy of the "随便聊聊" agent with MCP server details. + If it doesn't exist, it will be initialized. + + Args: + user: Authenticated user ID (injected by dependency) + db: Database session (injected by dependency) + + Returns: + AgentReadWithDetails: The user's chat agent with MCP server details + + Raises: + HTTPException: 404 if chat agent not found + """ + agent_repo = AgentRepository(db) + agents = await agent_repo.get_agents_by_user(user) + + chat_agent = next((a for a in agents if a.tags and "default_chat" in a.tags), None) + + if not chat_agent: + system_manager = SystemAgentManager(db) + new_agents = await system_manager.ensure_user_default_agents(user) + await db.commit() + chat_agent = next((a for a in new_agents if a.tags and "default_chat" in a.tags), None) + + if not chat_agent: + raise HTTPException(status_code=404, detail="Chat agent not found") + + # Get MCP servers for the agent + mcp_servers = await agent_repo.get_agent_mcp_servers(chat_agent.id) + + # Create agent dict with MCP servers + agent_dict = chat_agent.model_dump() + agent_dict["mcp_servers"] = mcp_servers + return AgentReadWithDetails(**agent_dict) + + +@router.get("/system/all", response_model=list[AgentReadWithDetails]) +async def get_all_system_agents( + user: str = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> list[AgentReadWithDetails]: + """ + Get all default agents for the user. 
+ + Returns the user's personal copies of system agents with MCP server details. + These are the agents tagged with 'default_'. + + Args: + user: Authenticated user ID (injected by dependency) + db: Database session (injected by dependency) + + Returns: + list[AgentReadWithDetails]: list of all user default agents with MCP server details + """ + agent_repo = AgentRepository(db) + agents = await agent_repo.get_agents_by_user(user) + + # Filter for default agents + default_agents = [a for a in agents if a.tags and any(t.startswith("default_") for t in a.tags)] + + if not default_agents: + system_manager = SystemAgentManager(db) + default_agents = await system_manager.ensure_user_default_agents(user) + await db.commit() + + # Load MCP servers for each system agent + agents_with_details = [] + + for agent in default_agents: + # Get MCP servers for this agent + mcp_servers = await agent_repo.get_agent_mcp_servers(agent.id) + + # Create agent dict with MCP servers + agent_dict = agent.model_dump() + agent_dict["mcp_servers"] = mcp_servers + agents_with_details.append(AgentReadWithDetails(**agent_dict)) + + return agents_with_details + + @router.get("/{agent_id}", response_model=AgentReadWithDetails) async def get_agent( agent_id: UUID, @@ -360,89 +474,34 @@ async def delete_agent( raise handle_auth_error(e) -@router.get("/system/chat", response_model=AgentReadWithDetails) -async def get_system_chat_agent( +@router.get("/{agent_id}/stats", response_model=AgentStatsAggregated) +async def get_agent_stats( + agent_id: UUID, user: str = Depends(get_current_user), db: AsyncSession = Depends(get_session), -) -> AgentReadWithDetails: +) -> AgentStatsAggregated: """ - Get the user's default chat agent. + Get aggregated stats for a specific agent. - Returns the user's personal copy of the "随便聊聊" agent with MCP server details. - If it doesn't exist, it will be initialized. + Stats are computed by aggregating data from sessions, topics, and messages + across all sessions the user has with this agent. 
Args: + agent_id: The UUID of the agent user: Authenticated user ID (injected by dependency) db: Database session (injected by dependency) Returns: - AgentReadWithDetails: The user's chat agent with MCP server details + AgentStatsAggregated: The agent's aggregated usage statistics Raises: - HTTPException: 404 if chat agent not found + HTTPException: 404 if agent not found or not owned by user """ agent_repo = AgentRepository(db) - agents = await agent_repo.get_agents_by_user(user) - - chat_agent = next((a for a in agents if a.tags and "default_chat" in a.tags), None) - - if not chat_agent: - system_manager = SystemAgentManager(db) - new_agents = await system_manager.ensure_user_default_agents(user) - await db.commit() - chat_agent = next((a for a in new_agents if a.tags and "default_chat" in a.tags), None) - - if not chat_agent: - raise HTTPException(status_code=404, detail="Chat agent not found") + agent = await agent_repo.get_agent_by_id(agent_id) - # Get MCP servers for the agent - mcp_servers = await agent_repo.get_agent_mcp_servers(chat_agent.id) - - # Create agent dict with MCP servers - agent_dict = chat_agent.model_dump() - agent_dict["mcp_servers"] = mcp_servers - return AgentReadWithDetails(**agent_dict) + if not agent or agent.user_id != user: + raise HTTPException(status_code=404, detail="Agent not found") - -@router.get("/system/all", response_model=list[AgentReadWithDetails]) -async def get_all_system_agents( - user: str = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> list[AgentReadWithDetails]: - """ - Get all default agents for the user. - - Returns the user's personal copies of system agents with MCP server details. - These are the agents tagged with 'default_'. - - Args: - user: Authenticated user ID (injected by dependency) - db: Database session (injected by dependency) - - Returns: - list[AgentReadWithDetails]: list of all user default agents with MCP server details - """ - agent_repo = AgentRepository(db) - agents = await agent_repo.get_agents_by_user(user) - - # Filter for default agents - default_agents = [a for a in agents if a.tags and any(t.startswith("default_") for t in a.tags)] - - if not default_agents: - system_manager = SystemAgentManager(db) - default_agents = await system_manager.ensure_user_default_agents(user) - await db.commit() - - # Load MCP servers for each system agent - agents_with_details = [] - - for agent in default_agents: - # Get MCP servers for this agent - mcp_servers = await agent_repo.get_agent_mcp_servers(agent.id) - - # Create agent dict with MCP servers - agent_dict = agent.model_dump() - agent_dict["mcp_servers"] = mcp_servers - agents_with_details.append(AgentReadWithDetails(**agent_dict)) - - return agents_with_details + stats_repo = SessionStatsRepository(db) + return await stats_repo.get_agent_stats(agent_id, user) diff --git a/service/app/core/session/service.py b/service/app/core/session/service.py index 2c41ef8e..7a181371 100644 --- a/service/app/core/session/service.py +++ b/service/app/core/session/service.py @@ -33,6 +33,7 @@ async def create_session_with_default_topic(self, session_data: SessionCreate, u agent_id=agent_uuid, provider_id=session_data.provider_id, model=session_data.model, + spatial_layout=session_data.spatial_layout, google_search_enabled=session_data.google_search_enabled, ) diff --git a/service/app/models/__init__.py b/service/app/models/__init__.py index 1a50ce2f..afdef71c 100644 --- a/service/app/models/__init__.py +++ b/service/app/models/__init__.py @@ -34,6 +34,7 @@ ) 
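Editorial note: the spatial_layout value passed through SessionCreate above is stored by the backend as an opaque JSON blob; its shape is only pinned down by the frontend's AgentSpatialLayout usage later in this series. A hedged sketch of a session-create body under that assumption; the gridSize key names and all coordinates are illustrative, not taken from this patch.

# Hypothetical request body for the session-create endpoint. Field names other
# than spatial_layout/model_tier come from SessionCreate; values are made up.
session_create_payload = {
    "model_tier": "pro",
    "google_search_enabled": False,
    "spatial_layout": {
        "position": {"x": 600, "y": -200},  # canvas coordinates of the agent node
        "gridSize": {"w": 2, "h": 1},       # grid footprint; exact keys assumed
        "size": "medium",                   # coarse size label used when gridSize is absent
    },
}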
from .provider import Provider from .redemption import RedemptionCode, RedemptionHistory, UserWallet +from .session_stats import AgentStatsAggregated, SessionStatsRead, UserStatsAggregated from .sessions import Session, SessionReadWithTopics from .smithery_cache import SmitheryServersCache from .tool import Tool, ToolFunction, ToolVersion @@ -87,6 +88,9 @@ "MessageReadWithFilesAndCitations", "Provider", "Session", + "AgentStatsAggregated", + "SessionStatsRead", + "UserStatsAggregated", "SessionReadWithTopics", "Tool", "ToolVersion", diff --git a/service/app/models/session_stats.py b/service/app/models/session_stats.py new file mode 100644 index 00000000..92946606 --- /dev/null +++ b/service/app/models/session_stats.py @@ -0,0 +1,48 @@ +""" +SessionStats schema for aggregated session usage statistics. + +This is NOT a database table - stats are computed by aggregating data from: +- sessions: session count per agent +- messages: message count per session/agent +- consume: token usage aggregated from consumption records + +The schemas here are used for API responses only. +""" + +from uuid import UUID + +from pydantic import BaseModel + + +class SessionStatsRead(BaseModel): + """Read schema for session statistics (aggregated, not stored).""" + + session_id: UUID + agent_id: UUID | None + topic_count: int = 0 + message_count: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + + +class AgentStatsAggregated(BaseModel): + """Aggregated stats for an agent across all sessions.""" + + agent_id: UUID + session_count: int = 0 + topic_count: int = 0 + message_count: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + + +class UserStatsAggregated(BaseModel): + """Aggregated stats for a user across all agents.""" + + user_id: str + agent_count: int = 0 + session_count: int = 0 + topic_count: int = 0 + message_count: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 diff --git a/service/app/models/sessions.py b/service/app/models/sessions.py index 7636a228..8daf47d9 100644 --- a/service/app/models/sessions.py +++ b/service/app/models/sessions.py @@ -1,9 +1,9 @@ import hashlib from datetime import datetime, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 -from sqlalchemy import TIMESTAMP +from sqlalchemy import JSON, TIMESTAMP from sqlmodel import Column, Field, SQLModel if TYPE_CHECKING: @@ -79,6 +79,11 @@ class SessionBase(SQLModel): google_search_enabled: bool = Field( default=False, description="Enable built-in web search for supported models (e.g., Gemini)" ) + spatial_layout: dict[str, Any] | None = Field( + default=None, + sa_column=Column(JSON, nullable=True), + description="Optional JSON blob for spatial UI layout (e.g., agent node positions, widget sizes)", + ) class Session(SessionBase, table=True): @@ -102,6 +107,7 @@ class SessionCreate(SQLModel): model: str | None = None model_tier: ModelTier | None = None google_search_enabled: bool = False + spatial_layout: dict[str, Any] | None = None class SessionRead(SessionBase): @@ -125,3 +131,4 @@ class SessionUpdate(SQLModel): model: str | None = None model_tier: ModelTier | None = None google_search_enabled: bool | None = None + spatial_layout: dict[str, Any] | None = None diff --git a/service/app/repos/session.py b/service/app/repos/session.py index a9b08e0e..40b78bdd 100644 --- a/service/app/repos/session.py +++ b/service/app/repos/session.py @@ -95,6 +95,7 @@ async def create_session(self, session_data: SessionCreate, user_id: str) -> Ses agent_id=agent_id, 
provider_id=session_data.provider_id, model=session_data.model, + spatial_layout=getattr(session_data, "spatial_layout", None), google_search_enabled=session_data.google_search_enabled, ) self.db.add(session) diff --git a/service/app/repos/session_stats.py b/service/app/repos/session_stats.py new file mode 100644 index 00000000..43678afa --- /dev/null +++ b/service/app/repos/session_stats.py @@ -0,0 +1,217 @@ +""" +Repository for session statistics aggregation. + +Computes stats by querying sessions, topics, messages, and consume tables. +No separate stats table needed - all data is aggregated on demand using +efficient database-level aggregation queries. + +Performance Note: Uses JOIN + GROUP BY for efficient single-pass aggregation. +Type hints are relaxed (type: ignore) because SQLAlchemy's typing doesn't +fully support all aggregation patterns, but runtime behavior is correct. +""" + +import logging +from uuid import UUID + +from sqlalchemy import and_, func +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.consume import ConsumeRecord +from app.models.message import Message +from app.models.session_stats import AgentStatsAggregated, SessionStatsRead +from app.models.sessions import Session +from app.models.topic import Topic + +logger = logging.getLogger(__name__) + + +class SessionStatsRepository: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def get_session_stats(self, session_id: UUID) -> SessionStatsRead | None: + """ + Get aggregated stats for a specific session. + + Uses efficient database aggregation to compute counts in a single query. + """ + # Get session info + session_stmt = select(Session).where(col(Session.id) == session_id) + session_result = await self.db.exec(session_stmt) + session = session_result.first() + + if not session: + return None + + # Count topics in this session + topic_count_stmt = select(func.count(Topic.id)).where(col(Topic.session_id) == session_id) # type: ignore + topic_result = await self.db.exec(topic_count_stmt) + topic_count = topic_result.one_or_none() or 0 + + # Count messages in all topics of this session + message_count_stmt = ( + select(func.count(Message.id)) # type: ignore[arg-type] + .select_from(Message) + .join(Topic, col(Message.topic_id) == col(Topic.id)) + .where(col(Topic.session_id) == session_id) + ) + message_result = await self.db.exec(message_count_stmt) + message_count = message_result.one_or_none() or 0 + + # Aggregate tokens from consume table + token_stmt = select( + func.coalesce(func.sum(ConsumeRecord.input_tokens), 0).label("input_tokens"), + func.coalesce(func.sum(ConsumeRecord.output_tokens), 0).label("output_tokens"), + ).where( + and_( + col(ConsumeRecord.session_id) == session_id, + col(ConsumeRecord.consume_state) == "success", + ) + ) + token_result = await self.db.exec(token_stmt) + token_row = token_result.first() + + return SessionStatsRead( + session_id=session_id, + agent_id=session.agent_id, + topic_count=int(topic_count), + message_count=int(message_count), + input_tokens=int(token_row[0]) if token_row else 0, # type: ignore[arg-type] + output_tokens=int(token_row[1]) if token_row else 0, # type: ignore[arg-type] + ) + + async def get_agent_stats(self, agent_id: UUID, user_id: str) -> AgentStatsAggregated: + """ + Get aggregated stats for an agent across all user's sessions. + + Uses efficient database aggregation with JOIN to compute all counts + in minimal queries (2 queries: counts + tokens). 
+ """ + # Single query to get session_count, topic_count, message_count using JOINs + stats_stmt = ( + select( + func.count(func.distinct(Session.id)).label("session_count"), + func.count(func.distinct(Topic.id)).label("topic_count"), + func.count(Message.id).label("message_count"), # type: ignore[arg-type] + ) + .select_from(Session) + .outerjoin(Topic, col(Topic.session_id) == col(Session.id)) + .outerjoin(Message, col(Message.topic_id) == col(Topic.id)) + .where( + and_( + col(Session.agent_id) == agent_id, + col(Session.user_id) == user_id, + col(Session.is_active) == True, # noqa: E712 + ) + ) + ) + stats_result = await self.db.exec(stats_stmt) + stats_row = stats_result.first() + + # Aggregate tokens from consume table for this agent's sessions + token_stmt = ( + select( + func.coalesce(func.sum(ConsumeRecord.input_tokens), 0).label("input_tokens"), + func.coalesce(func.sum(ConsumeRecord.output_tokens), 0).label("output_tokens"), + ) + .select_from(ConsumeRecord) + .join(Session, col(ConsumeRecord.session_id) == col(Session.id)) + .where( + and_( + col(Session.agent_id) == agent_id, + col(Session.user_id) == user_id, + col(ConsumeRecord.consume_state) == "success", + ) + ) + ) + token_result = await self.db.exec(token_stmt) + token_row = token_result.first() + + return AgentStatsAggregated( + agent_id=agent_id, + session_count=int(stats_row[0]) if stats_row else 0, # type: ignore + topic_count=int(stats_row[1]) if stats_row else 0, # type: ignore + message_count=int(stats_row[2]) if stats_row else 0, # type: ignore + input_tokens=int(token_row[0]) if token_row else 0, # type: ignore + output_tokens=int(token_row[1]) if token_row else 0, # type: ignore + ) + + async def get_all_agent_stats_for_user(self, user_id: str) -> dict[str, AgentStatsAggregated]: + """ + Get aggregated stats for all agents a user has used. + + Uses efficient database aggregation with GROUP BY to compute all stats + in just two queries (one for counts, one for tokens). + This is more efficient than querying per-agent. 
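        Editorial note: to make the two-query pattern described above concrete, a small
        illustrative sketch of how the grouped rows are expected to combine into the
        response dict. The UUID and counts are made up; the tuple ordering simply mirrors
        the (agent_id, ...) column order of the two statements below.

            from uuid import UUID
            from app.models.session_stats import AgentStatsAggregated

            # Hypothetical rows: (agent_id, session_count, topic_count, message_count)
            stats_rows = [(UUID("11111111-1111-1111-1111-111111111111"), 2, 5, 31)]
            # Hypothetical rows: (agent_id, input_tokens, output_tokens)
            token_rows = [(UUID("11111111-1111-1111-1111-111111111111"), 54000, 9100)]

            tokens = {row[0]: (int(row[1]), int(row[2])) for row in token_rows}
            result = {
                str(agent_id): AgentStatsAggregated(
                    agent_id=agent_id,
                    session_count=int(sessions),
                    topic_count=int(topics),
                    message_count=int(messages),
                    input_tokens=tokens.get(agent_id, (0, 0))[0],
                    output_tokens=tokens.get(agent_id, (0, 0))[1],
                )
                for agent_id, sessions, topics, messages in stats_rows
            }
            # Keys are stringified agent UUIDs, matching the dict[str, AgentStatsAggregated]
            # response model of the GET /stats endpoint.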
+ """ + # Single query to get all agent stats with GROUP BY using JOINs + stats_stmt = ( + select( + Session.agent_id, + func.count(func.distinct(Session.id)).label("session_count"), + func.count(func.distinct(Topic.id)).label("topic_count"), + func.count(Message.id).label("message_count"), # type: ignore[arg-type] + ) + .select_from(Session) + .outerjoin(Topic, col(Topic.session_id) == col(Session.id)) + .outerjoin(Message, col(Message.topic_id) == col(Topic.id)) + .where( + and_( + col(Session.user_id) == user_id, + col(Session.agent_id).isnot(None), + col(Session.is_active) == True, # noqa: E712 + ) + ) + .group_by(col(Session.agent_id)) + ) + stats_result = await self.db.exec(stats_stmt) + stats_rows = list(stats_result.all()) + + if not stats_rows: + return {} + + # Single query to get token aggregates grouped by agent + token_stmt = ( + select( + Session.agent_id, + func.coalesce(func.sum(ConsumeRecord.input_tokens), 0).label("input_tokens"), + func.coalesce(func.sum(ConsumeRecord.output_tokens), 0).label("output_tokens"), + ) + .select_from(ConsumeRecord) + .join(Session, col(ConsumeRecord.session_id) == col(Session.id)) + .where( + and_( + col(Session.user_id) == user_id, + col(Session.agent_id).isnot(None), + col(ConsumeRecord.consume_state) == "success", + ) + ) + .group_by(col(Session.agent_id)) + ) + token_result = await self.db.exec(token_stmt) + token_rows = list(token_result.all()) + + # Build token lookup dict + token_by_agent: dict[UUID, tuple[int, int]] = {} + for row in token_rows: + agent_id = row[0] + if agent_id: + token_by_agent[agent_id] = (int(row[1]), int(row[2])) # type: ignore + + # Build result dict + result: dict[str, AgentStatsAggregated] = {} + for row in stats_rows: # type: ignore + agent_id = row[0] + if agent_id: + input_tokens, output_tokens = token_by_agent.get(agent_id, (0, 0)) + result[str(agent_id)] = AgentStatsAggregated( + agent_id=agent_id, + session_count=int(row[1]), # type: ignore + topic_count=int(row[2]), # type: ignore + message_count=int(row[3]), # type: ignore + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + + return result diff --git a/service/migrations/versions/16fcd09524ba_add_session_spatial_layout.py b/service/migrations/versions/16fcd09524ba_add_session_spatial_layout.py new file mode 100644 index 00000000..48f7709b --- /dev/null +++ b/service/migrations/versions/16fcd09524ba_add_session_spatial_layout.py @@ -0,0 +1,33 @@ +"""add session spatial_layout + +Revision ID: 16fcd09524ba +Revises: c712246c7034 +Create Date: 2026-01-14 23:53:36.058092 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "16fcd09524ba" +down_revision: Union[str, Sequence[str], None] = "c712246c7034" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("session", sa.Column("spatial_layout", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("session", "spatial_layout") + # ### end Alembic commands ### diff --git a/service/migrations/versions/e427ec7ce799_add_session_status.py b/service/migrations/versions/e427ec7ce799_add_session_status.py new file mode 100644 index 00000000..d207efc2 --- /dev/null +++ b/service/migrations/versions/e427ec7ce799_add_session_status.py @@ -0,0 +1,30 @@ +"""add session status + +Revision ID: e427ec7ce799 +Revises: 16fcd09524ba +Create Date: 2026-01-15 00:51:11.625488 + +""" + +from typing import Sequence, Union + + +# revision identifiers, used by Alembic. +revision: str = "e427ec7ce799" +down_revision: Union[str, Sequence[str], None] = "16fcd09524ba" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/web/src/app/chat/SpatialWorkspace.tsx b/web/src/app/chat/SpatialWorkspace.tsx index caf16113..d24357c1 100644 --- a/web/src/app/chat/SpatialWorkspace.tsx +++ b/web/src/app/chat/SpatialWorkspace.tsx @@ -1,6 +1,5 @@ import { Background, - Node, ReactFlow, ReactFlowProvider, useEdgesState, @@ -10,83 +9,75 @@ import { import "@xyflow/react/dist/style.css"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import AddAgentModal from "@/components/modals/AddAgentModal"; +import { useXyzen } from "@/store"; +import type { + AgentSpatialLayout, + AgentStatsAggregated, + AgentWithLayout, +} from "@/types/agents"; import { AnimatePresence } from "framer-motion"; +import { AddAgentButton } from "./spatial/AddAgentButton"; import { AgentNode } from "./spatial/AgentNode"; import { FocusedView } from "./spatial/FocusedView"; -import type { AgentData, FlowAgentNodeData } from "./spatial/types"; - -type AgentFlowNode = Node; - -// --- Mock Data --- -const INITIAL_AGENTS: AgentData[] = [ - { - name: "Market Analyst Pro", - role: "Market Analyst", - desc: "Expert in trend forecasting", - avatar: "https://api.dicebear.com/7.x/avataaars/svg?seed=Market", - status: "busy", - size: "large", - }, - { - name: "Creative Writer", - role: "Copywriter", - desc: "Marketing copy & storytelling", - avatar: "https://api.dicebear.com/7.x/avataaars/svg?seed=Creative", - status: "idle", - size: "medium", - }, - { - name: "Global Search", - role: "Researcher", - desc: "Real-time info retrieval", - avatar: "https://api.dicebear.com/7.x/avataaars/svg?seed=Search", - status: "idle", - size: "small", - }, - { - name: "Code Auditor", - role: "Security", - desc: "Python/JS security checks", - avatar: "https://api.dicebear.com/7.x/avataaars/svg?seed=Code", - status: "idle", - size: "medium", - }, -]; - -const noopFocus: FlowAgentNodeData["onFocus"] = () => {}; - -const INITIAL_NODES: AgentFlowNode[] = [ - { - id: "1", - type: "agent", - position: { x: 0, y: 0 }, - data: { ...INITIAL_AGENTS[0], onFocus: noopFocus }, - }, - { - id: "2", - type: "agent", - position: { x: 600, y: -200 }, - data: { ...INITIAL_AGENTS[1], onFocus: noopFocus }, - }, - { - id: "3", - type: "agent", - position: { x: -300, y: 400 }, - data: { ...INITIAL_AGENTS[2], onFocus: noopFocus }, - }, - { - id: "4", +import { + SaveStatusIndicator, + type SaveStatus, +} from "./spatial/SaveStatusIndicator"; +import type { + AgentData, + AgentFlowNode, + 
AgentStatsDisplay, + FlowAgentNodeData, +} from "./spatial/types"; + +/** + * Convert AgentWithLayout to AgentFlowNode for ReactFlow rendering. + * Role defaults to first line of description for UI display. + * stats is derived from agentStats for visualization. + */ +const agentToFlowNode = ( + agent: AgentWithLayout, + stats?: AgentStatsAggregated, +): AgentFlowNode => { + const statsDisplay: AgentStatsDisplay | undefined = stats + ? { + messageCount: stats.message_count, + topicCount: stats.topic_count, + inputTokens: stats.input_tokens, + outputTokens: stats.output_tokens, + } + : undefined; + + return { + id: agent.id, type: "agent", - position: { x: 700, y: 500 }, - data: { ...INITIAL_AGENTS[3], onFocus: noopFocus }, - }, -]; + position: agent.spatial_layout.position, + data: { + name: agent.name, + role: (agent.description?.split("\n")[0] || "Agent") as string, + desc: agent.description || "", + avatar: + agent.avatar || + "https://api.dicebear.com/7.x/avataaars/svg?seed=default", + status: "idle", + size: agent.spatial_layout.size || "medium", + gridSize: agent.spatial_layout.gridSize, + position: agent.spatial_layout.position, + stats: statsDisplay, + onFocus: () => {}, + } as FlowAgentNodeData, + }; +}; function InnerWorkspace() { - const [nodes, setNodes, onNodesChange] = - useNodesState(INITIAL_NODES); + const { agents, fetchAgents, updateAgentLayout, agentStats } = useXyzen(); + + const [nodes, setNodes, onNodesChange] = useNodesState([]); const [edges, , onEdgesChange] = useEdgesState([]); const [focusedAgentId, setFocusedAgentId] = useState(null); + const [saveStatus, setSaveStatus] = useState("idle"); + const [isAddModalOpen, setAddModalOpen] = useState(false); const containerRef = useRef(null); const [prevViewport, setPrevViewport] = useState<{ x: number; @@ -98,9 +89,85 @@ function InnerWorkspace() { const cancelInitialFitRef = useRef(false); const initialFitAttemptsRef = useRef(0); + // Debounce save timers + const saveTimerRef = useRef | null>(null); + const pendingSavesRef = useRef>(new Map()); + const savedTimerRef = useRef | null>(null); + + // Fetch agents on mount + useEffect(() => { + fetchAgents().catch((err) => console.error("Failed to fetch agents:", err)); + }, [fetchAgents]); + + // Debounced save function + const scheduleSave = useCallback( + (agentId: string, layout: AgentSpatialLayout) => { + pendingSavesRef.current.set(agentId, layout); + + // Clear existing timer + if (saveTimerRef.current) { + clearTimeout(saveTimerRef.current); + } + + // Set saving status + setSaveStatus("saving"); + + // Debounce: save after 800ms of no changes + saveTimerRef.current = setTimeout(async () => { + const saves = Array.from(pendingSavesRef.current.entries()); + pendingSavesRef.current.clear(); + + try { + // Save all pending layouts + await Promise.all( + saves.map(([id, layout]) => updateAgentLayout(id, layout)), + ); + + setSaveStatus("saved"); + + // Clear saved status after 2 seconds + if (savedTimerRef.current) clearTimeout(savedTimerRef.current); + savedTimerRef.current = setTimeout(() => setSaveStatus("idle"), 2000); + } catch (error) { + console.error("Failed to save layouts:", error); + setSaveStatus("failed"); + } + }, 800); + }, + [updateAgentLayout], + ); + + // Retry failed saves + const handleRetrySave = useCallback(() => { + const saves = Array.from(pendingSavesRef.current.entries()); + if (saves.length > 0) { + setSaveStatus("saving"); + Promise.all(saves.map(([id, layout]) => updateAgentLayout(id, layout))) + .then(() => { + 
pendingSavesRef.current.clear(); + setSaveStatus("saved"); + if (savedTimerRef.current) clearTimeout(savedTimerRef.current); + savedTimerRef.current = setTimeout(() => setSaveStatus("idle"), 2000); + }) + .catch(() => setSaveStatus("failed")); + } + }, [updateAgentLayout]); + + // Update nodes whenever agents or stats change + useEffect(() => { + if (agents.length > 0) { + const flowNodes = agents.map((agent) => { + const stats = agentStats[agent.id]; + return agentToFlowNode(agent, stats); + }); + setNodes(flowNodes); + } + }, [agents, agentStats, setNodes]); + useEffect(() => { if (didInitialFitViewRef.current) return; if (cancelInitialFitRef.current) return; + if (nodes.length === 0) return; // Don't fit empty viewport let cancelled = false; initialFitAttemptsRef.current = 0; @@ -231,7 +298,11 @@ function InnerWorkspace() { }; setNodes((prev) => { - const next = prev.map((n) => ({ ...n })); + const next = prev.map((n) => ({ + ...n, + position: { ...n.position }, + data: { ...(n.data as FlowAgentNodeData) }, + })); const moving = next.find((n) => n.id === draggedNode.id); if (!moving) return prev; @@ -287,10 +358,21 @@ function InnerWorkspace() { if (!movedThisIter) break; } + // Keep persistable position in sync for future storage. + moving.data.position = { ...moving.position }; + + // Schedule auto-save for this agent + const agentData = moving.data as FlowAgentNodeData; + scheduleSave(moving.id, { + position: moving.position, + gridSize: agentData.gridSize, + size: agentData.size, + }); + return next; }); }, - [getNode, setNodes], + [getNode, setNodes, scheduleSave], ); const focusedAgent = useMemo(() => { @@ -322,6 +404,12 @@ function InnerWorkspace() { + {/* Save Status Indicator */} + + + {/* Add Agent Button */} + setAddModalOpen(true)} /> + {focusedAgent && ( )} + + {/* Add Agent Modal */} + setAddModalOpen(false)} + />
); } diff --git a/web/src/app/chat/spatial/AddAgentButton.tsx b/web/src/app/chat/spatial/AddAgentButton.tsx new file mode 100644 index 00000000..327251b2 --- /dev/null +++ b/web/src/app/chat/spatial/AddAgentButton.tsx @@ -0,0 +1,39 @@ +/** + * AddAgentButton - Floating action button to add new agents + */ +import { PlusIcon } from "@heroicons/react/24/outline"; +import { motion } from "framer-motion"; + +interface AddAgentButtonProps { + onClick: () => void; +} + +export function AddAgentButton({ onClick }: AddAgentButtonProps) { + return ( + + + + {/* Ripple effect on hover */} + + + ); +} diff --git a/web/src/app/chat/spatial/AgentNode.tsx b/web/src/app/chat/spatial/AgentNode.tsx index 64ebb30a..367e5b33 100644 --- a/web/src/app/chat/spatial/AgentNode.tsx +++ b/web/src/app/chat/spatial/AgentNode.tsx @@ -1,13 +1,14 @@ import { Modal } from "@/components/animate-ui/components/animate/modal"; import { cn } from "@/lib/utils"; -import { Cog6ToothIcon } from "@heroicons/react/24/outline"; -import type { Node } from "@xyflow/react"; -import { NodeProps, useReactFlow } from "@xyflow/react"; +import { + ChatBubbleLeftRightIcon, + Cog6ToothIcon, + DocumentTextIcon, +} from "@heroicons/react/24/outline"; +import { useReactFlow } from "@xyflow/react"; import { motion } from "framer-motion"; import { useState } from "react"; -import type { FlowAgentNodeData } from "./types"; - -type AgentFlowNode = Node; +import type { AgentFlowNodeProps, AgentStatsDisplay } from "./types"; // Helper to calc size // Base unit: 1x1 = 200x160. Gap = 16. @@ -29,6 +30,154 @@ const getSizeStyle = (w?: number, h?: number, sizeStr?: string) => { return { width: 200, height: 160 }; }; +// Format token count for display +const formatTokenCount = (count: number): string => { + if (count >= 1000000) return `${(count / 1000000).toFixed(1)}M`; + if (count >= 1000) return `${(count / 1000).toFixed(1)}K`; + return count.toString(); +}; + +// Stats display component with responsive layout +function StatsDisplay({ + stats, + gridW, + gridH, +}: { + stats?: AgentStatsDisplay; + gridW: number; + gridH: number; +}) { + if (!stats) return null; + + const totalTokens = stats.inputTokens + stats.outputTokens; + const hasActivity = stats.messageCount > 0 || stats.topicCount > 0; + const area = gridW * gridH; + + // Compact 1x1: Only show message count as badge + if (area === 1) { + if (!hasActivity) return null; + return ( +
+ + {stats.messageCount} +
+ ); + } + + // 2x1 horizontal: Compact inline stats + if (gridW >= 2 && gridH === 1) { + return ( +
+
+ + {stats.messageCount} +
+
+ + {stats.topicCount} +
+ {totalTokens > 0 && ( +
+ {formatTokenCount(totalTokens)} tokens +
+ )} +
+ ); + } + + // 1x2 vertical: Stacked stats + if (gridW === 1 && gridH >= 2) { + return ( +
+
+ + {stats.messageCount} messages +
+
+ + {stats.topicCount} topics +
+ {totalTokens > 0 && ( +
+ {formatTokenCount(totalTokens)} tokens +
+ )} +
+ ); + } + + // 2x2 or larger: Full stats grid with visual bars + return ( +
+ {/* Stats row */} +
+
+ +
+ {stats.messageCount} +
+ msgs +
+
+ +
+ {stats.topicCount} +
+ topics +
+
+ + {/* Token usage bar */} + {totalTokens > 0 && ( +
+
+ Token Usage + {formatTokenCount(totalTokens)} +
+
+
+
+
+
+ ↓ {formatTokenCount(stats.inputTokens)} + ↑ {formatTokenCount(stats.outputTokens)} +
+
+ )} + + {/* Activity visualization for larger sizes */} + {area >= 6 && hasActivity && ( +
+ {Array.from({ length: Math.min(stats.topicCount + 2, 8) }).map( + (_, i) => ( +
+ ), + )} +
+ )} +
+ ); +} + function GridResizer({ currentW = 1, currentH = 1, @@ -83,7 +232,7 @@ function GridResizer({ ); } -export function AgentNode({ id, data, selected }: NodeProps) { +export function AgentNode({ id, data, selected }: AgentFlowNodeProps) { const { updateNodeData } = useReactFlow(); const [isSettingsOpen, setIsSettingsOpen] = useState(false); // Determine current dim @@ -134,7 +283,7 @@ export function AgentNode({ id, data, selected }: NodeProps) { )} @@ -146,7 +295,7 @@ export function AgentNode({ id, data, selected }: NodeProps) { ? "ring-2 ring-[#5a6e8c]/20 dark:ring-0 dark:border-indigo-400/50 dark:shadow-[0_0_15px_rgba(99,102,241,0.5),0_0_30px_rgba(168,85,247,0.3)] shadow-2xl" : "hover:shadow-2xl", data.isFocused && - "ring-0 !border-white/20 dark:!border-white/10 !shadow-none bg-white/90 dark:bg-black/80", // Cleaner look when focused + "ring-0 border-white/20! dark:border-white/10! shadow-none! bg-white/90 dark:bg-black/80", // Cleaner look when focused )} /> @@ -164,24 +313,24 @@ export function AgentNode({ id, data, selected }: NodeProps) {
-
+
avatar -
-
+
+
{data.name}
-
+
{data.role}
-
+
{data.status === "busy" && (
@@ -191,24 +340,12 @@ export function AgentNode({ id, data, selected }: NodeProps) {
- {/* Dynamic Abstract Viz based on size */} - {((data.gridSize && data.gridSize.w * data.gridSize.h >= 2) || - data.size === "large" || - data.size === "medium") && ( -
-
- {[40, 70, 55, 90, 60, 80] - .slice(0, currentW * currentH + 2) - .map((h, i) => ( -
- ))} -
-
- )} + {/* Stats Display - Responsive to grid size */} +
diff --git a/web/src/app/chat/spatial/SaveStatusIndicator.tsx b/web/src/app/chat/spatial/SaveStatusIndicator.tsx new file mode 100644 index 00000000..18b9e15b --- /dev/null +++ b/web/src/app/chat/spatial/SaveStatusIndicator.tsx @@ -0,0 +1,75 @@ +/** + * SaveStatusIndicator - Shows auto-save status in the top-right corner + */ +import { + CheckIcon, + ExclamationTriangleIcon, +} from "@heroicons/react/24/outline"; +import { AnimatePresence, motion } from "framer-motion"; + +export type SaveStatus = "idle" | "saving" | "saved" | "failed"; + +interface SaveStatusIndicatorProps { + status: SaveStatus; + onRetry?: () => void; +} + +export function SaveStatusIndicator({ + status, + onRetry, +}: SaveStatusIndicatorProps) { + if (status === "idle") return null; + + return ( + + + {status === "saving" && ( +
+ + + Saving... + +
+ )} + + {status === "saved" && ( + + + + Saved + + + )} + + {status === "failed" && ( + + + + Failed - Click to retry + + + )} +
+
+ ); +} diff --git a/web/src/app/chat/spatial/types.ts b/web/src/app/chat/spatial/types.ts index face384c..88a292a8 100644 --- a/web/src/app/chat/spatial/types.ts +++ b/web/src/app/chat/spatial/types.ts @@ -1,16 +1,48 @@ +import type { Node, NodeProps } from "@xyflow/react"; + +export type XYPosition = { x: number; y: number }; +export type GridSize = { w: number; h: number }; +export type AgentWidgetSize = "large" | "medium" | "small"; + +/** + * Stats data for agent display. + */ +export interface AgentStatsDisplay { + messageCount: number; + topicCount: number; + inputTokens: number; + outputTokens: number; +} + +/** + * Persistable agent widget data (no functions). + * + * Note: XYFlow stores position on the Node itself, but we also keep a copy here + * so it can be persisted/serialized without the full Node shape. + */ export interface AgentData { name: string; role: string; desc: string; avatar: string; status: "idle" | "busy" | "offline"; - size: "large" | "medium" | "small"; - gridSize?: { w: number; h: number }; // 1-3 grid system + size: AgentWidgetSize; + gridSize?: GridSize; // 1-3 grid system + position: XYPosition; + /** Stats for display visualization */ + stats?: AgentStatsDisplay; } -export interface AgentNodeData extends AgentData { + +/** Runtime-only fields injected by the workspace. */ +export interface AgentNodeRuntimeData { onFocus: (id: string) => void; isFocused?: boolean; } +export type AgentNodeData = AgentData & AgentNodeRuntimeData; + // XYFlow requires node.data to be a Record export type FlowAgentNodeData = AgentNodeData & Record; + +export type AgentFlowNode = Node; +export type AgentFlowNodeProps = NodeProps; diff --git a/web/src/core/session/types.ts b/web/src/core/session/types.ts index 9deb50f9..6372f71d 100644 --- a/web/src/core/session/types.ts +++ b/web/src/core/session/types.ts @@ -16,6 +16,7 @@ export interface SessionResponse { model?: string; model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; + spatial_layout?: Record | null; created_at: string; updated_at: string; topics?: TopicResponse[]; diff --git a/web/src/service/sessionService.ts b/web/src/service/sessionService.ts index 2d8feade..db593436 100644 --- a/web/src/service/sessionService.ts +++ b/web/src/service/sessionService.ts @@ -1,5 +1,6 @@ import { authService } from "@/service/authService"; import { useXyzen } from "@/store"; +import type { AgentSpatialLayout } from "@/types/agents"; export interface SessionCreate { name: string; @@ -10,6 +11,7 @@ export interface SessionCreate { model?: string; model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; + spatial_layout?: AgentSpatialLayout; } export interface SessionUpdate { @@ -20,6 +22,7 @@ export interface SessionUpdate { model?: string; model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; + spatial_layout?: AgentSpatialLayout; } export interface SessionRead { @@ -33,6 +36,7 @@ export interface SessionRead { model?: string; model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; + spatial_layout?: AgentSpatialLayout; created_at: string; updated_at: string; } diff --git a/web/src/store/slices/agentSlice.ts b/web/src/store/slices/agentSlice.ts index acb3a87c..732ebf84 100644 --- a/web/src/store/slices/agentSlice.ts +++ b/web/src/store/slices/agentSlice.ts @@ -1,18 +1,35 @@ import { authService } from "@/service/authService"; -import type { Agent, SystemAgentTemplate } from "@/types/agents"; +import { sessionService } 
from "@/service/sessionService"; +import type { + Agent, + AgentSpatialLayout, + AgentStatsAggregated, + AgentWithLayout, + SystemAgentTemplate, +} from "@/types/agents"; import type { StateCreator } from "zustand"; import type { XyzenState } from "../types"; export interface AgentSlice { - agents: Agent[]; + agents: AgentWithLayout[]; agentsLoading: boolean; + // Map from agentId -> sessionId for layout persistence + // Layout is stored in Session, not Agent + sessionIdByAgentId: Record; + + // Agent stats for growth visualization (aggregated from sessions/messages) + agentStats: Record; + agentStatsLoading: boolean; + // System agent templates systemAgentTemplates: SystemAgentTemplate[]; templatesLoading: boolean; fetchSystemAgentTemplates: () => Promise; fetchAgents: () => Promise; + fetchAgentStats: () => Promise; + incrementLocalAgentMessageCount: (agentId: string) => void; isCreatingAgent: boolean; createAgent: (agent: Omit) => Promise; @@ -20,7 +37,11 @@ export interface AgentSlice { systemKey: string, customName?: string, ) => Promise; - updateAgent: (agent: Agent) => Promise; + updateAgent: (agent: Agent | AgentWithLayout) => Promise; + updateAgentLayout: ( + agentId: string, + layout: AgentSpatialLayout, + ) => Promise; updateAgentProvider: ( agentId: string, providerId: string | null, @@ -28,6 +49,20 @@ export interface AgentSlice { deleteAgent: (id: string) => Promise; } +const defaultSpatialLayoutForIndex = ( + index: number, +): AgentWithLayout["spatial_layout"] => { + // Simple deterministic grid so the spatial UI has stable starting positions. + // Default to 2x1 grid size for compact horizontal layout + const col = index % 3; + const row = Math.floor(index / 3); + return { + position: { x: col * 360, y: row * 220 }, + size: "medium", + gridSize: { w: 2, h: 1 }, + }; +}; + // 创建带认证头的请求选项 const createAuthHeaders = (): HeadersInit => { const token = authService.getToken(); @@ -51,6 +86,13 @@ export const createAgentSlice: StateCreator< agents: [], agentsLoading: false, + // Map from agentId -> sessionId + sessionIdByAgentId: {}, + + // Agent stats state + agentStats: {}, + agentStatsLoading: false, + // System agent templates state systemAgentTemplates: [], templatesLoading: false, @@ -91,8 +133,43 @@ export const createAgentSlice: StateCreator< throw new Error("Failed to fetch agents"); } - const agents: Agent[] = await response.json(); - set({ agents, agentsLoading: false }); + const rawAgents: Agent[] = await response.json(); + + // Fetch sessions for each agent to get spatial_layout + // Build session mapping and extract layouts + const sessionMap: Record = {}; + const layoutMap: Record = {}; + + await Promise.all( + rawAgents.map(async (agent) => { + try { + const session = await sessionService.getSessionByAgent(agent.id); + sessionMap[agent.id] = session.id; + if (session.spatial_layout) { + layoutMap[agent.id] = session.spatial_layout; + } + } catch { + // Session doesn't exist yet - will be created when user starts chat + console.debug(`No session found for agent ${agent.id}`); + } + }), + ); + + // Enrich agents with layout from session or default + const agents: AgentWithLayout[] = rawAgents.map((agent, index) => ({ + ...agent, + spatial_layout: + layoutMap[agent.id] ?? 
defaultSpatialLayoutForIndex(index), + })); + + set({ + agents, + agentsLoading: false, + sessionIdByAgentId: sessionMap, + }); + + // Also fetch stats for growth visualization + get().fetchAgentStats(); } catch (error) { console.error("Failed to fetch agents:", error); set({ agentsLoading: false }); @@ -100,6 +177,56 @@ export const createAgentSlice: StateCreator< } }, + fetchAgentStats: async () => { + set({ agentStatsLoading: true }); + try { + const response = await fetch( + `${get().backendUrl}/xyzen/api/v1/agents/stats`, + { + headers: createAuthHeaders(), + }, + ); + + if (!response.ok) { + throw new Error("Failed to fetch agent stats"); + } + + const stats: Record = await response.json(); + set({ agentStats: stats, agentStatsLoading: false }); + } catch (error) { + console.error("Failed to fetch agent stats:", error); + set({ agentStatsLoading: false }); + // Don't throw - stats are optional enhancement + } + }, + + /** + * Optimistically increment the local message count for an agent. + * Used for immediate UI feedback when a message is sent. + * Actual stats will sync from backend on next fetchAgentStats(). + */ + incrementLocalAgentMessageCount: (agentId) => { + set((state) => { + const existingStats = state.agentStats[agentId]; + if (existingStats) { + state.agentStats[agentId] = { + ...existingStats, + message_count: existingStats.message_count + 1, + }; + } else { + // Create placeholder stats if not yet fetched + state.agentStats[agentId] = { + agent_id: agentId, + session_count: 0, + topic_count: 0, + message_count: 1, + input_tokens: 0, + output_tokens: 0, + }; + } + }); + }, + createAgent: async (agent) => { const { isCreatingAgent } = get(); if (isCreatingAgent) { @@ -189,6 +316,64 @@ export const createAgentSlice: StateCreator< } }, + updateAgentLayout: async (agentId, layout) => { + try { + // Get the session ID for this agent + let sessionId = get().sessionIdByAgentId[agentId]; + + if (!sessionId) { + // Try to fetch the session if not cached + try { + const session = await sessionService.getSessionByAgent(agentId); + sessionId = session.id; + // Cache it + set((state) => { + state.sessionIdByAgentId[agentId] = sessionId; + }); + } catch { + // Session doesn't exist - create one first + // This happens when user drags an agent that hasn't been used yet + console.warn( + `No session found for agent ${agentId}, creating one...`, + ); + const agent = get().agents.find((a) => a.id === agentId); + const newSession = await sessionService.createSession({ + name: agent?.name ?? 
"Agent Session", + agent_id: agentId, + spatial_layout: layout, + }); + sessionId = newSession.id; + set((state) => { + state.sessionIdByAgentId[agentId] = sessionId; + }); + + // Update local state optimistically + set((state) => { + const agentData = state.agents.find((a) => a.id === agentId); + if (agentData) { + agentData.spatial_layout = layout; + } + }); + return; + } + } + + // Update the session's spatial_layout via Session API + await sessionService.updateSession(sessionId, { spatial_layout: layout }); + + // Update local state optimistically + set((state) => { + const agent = state.agents.find((a) => a.id === agentId); + if (agent) { + agent.spatial_layout = layout; + } + }); + } catch (error) { + console.error("Failed to update agent layout:", error); + throw error; + } + }, + updateAgentProvider: async (agentId, providerId) => { try { const response = await fetch( diff --git a/web/src/store/types.ts b/web/src/store/types.ts index 4bedff71..77a88dbd 100644 --- a/web/src/store/types.ts +++ b/web/src/store/types.ts @@ -1,3 +1,4 @@ +import type { AgentExecutionState } from "@/types/agentEvents"; import type { AgentSlice, AuthSlice, @@ -9,7 +10,6 @@ import type { ProviderSlice, UiSlice, } from "./slices"; -import type { AgentExecutionState } from "@/types/agentEvents"; // 定义应用中的核心类型 export interface ToolCall { @@ -203,6 +203,7 @@ export interface SessionResponse { model?: string; model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; + spatial_layout?: Record | null; topics: TopicResponse[]; } diff --git a/web/src/types/agents.ts b/web/src/types/agents.ts index 6505d79b..dfe54bc4 100644 --- a/web/src/types/agents.ts +++ b/web/src/types/agents.ts @@ -1,5 +1,50 @@ // Agent type definitions and type guards +// Spatial/layout primitives (used by spatial chat UI) +export type XYPosition = { x: number; y: number }; +export type GridSize = { w: number; h: number }; +export type AgentWidgetSize = "large" | "medium" | "small"; + +export interface AgentSpatialLayout { + position: XYPosition; + gridSize?: GridSize; + size?: AgentWidgetSize; +} + +/** + * Aggregated stats for an agent (computed from sessions/messages/consume, not stored). + * This matches the backend AgentStatsAggregated schema. + */ +export interface AgentStatsAggregated { + agent_id: string; + session_count: number; + topic_count: number; + message_count: number; + input_tokens: number; + output_tokens: number; +} + +/** + * Calculate the visual scale multiplier based on message count. + * Uses a logarithmic curve for diminishing returns: + * - Each message adds 1/1000 growth initially + * - Growth rate diminishes as count increases + * - Capped at 2x size to prevent UI overflow + * + * Formula: scale = 1 + 0.3 * ln(1 + messageCount / 100) + * - At 0 messages: scale = 1.0 + * - At 100 messages: scale = 1.21 + * - At 500 messages: scale = 1.54 + * - At 1000 messages: scale = 1.69 + * - Asymptotically approaches 2.0 + */ +export const calculateGrowthScale = (messageCount: number): number => { + // Logarithmic growth with diminishing returns + const scale = 1 + 0.3 * Math.log(1 + messageCount / 100); + // Cap at 2x to prevent overflow + return Math.min(scale, 2.0); +}; + // Metadata for a system agent template export interface SystemAgentMetadata { name: string; @@ -56,6 +101,16 @@ export interface Agent { graph_config?: Record | null; } +/** + * UI-enriched agent type. + * Note: `spatial_layout` is a frontend concern and is not sent to the Agent API. 
+ */ +export type AgentWithLayout = Agent & { + spatial_layout: AgentSpatialLayout; + /** Aggregated usage stats for growth visualization; fetched from /agents/stats API */ + stats?: AgentStatsAggregated; +}; + // System/builtin agents (official agents provided by the platform) export interface SystemAgent extends Agent { is_official?: boolean; From 438b158ca7a10ca3a08c07204b10a3e67b565d01 Mon Sep 17 00:00:00 2001 From: Harvey Date: Thu, 15 Jan 2026 02:29:15 +0800 Subject: [PATCH 07/11] feat: Add EditAgentModal and SessionSettingsModal for agent management - Introduced EditAgentModal for editing agent details within the workspace. - Added SessionSettingsModal to manage avatar and grid size settings for agents. - Enhanced agent data structure to include sessionId and avatar properties. - Implemented daily activity tracking and yesterday's summary for agents. - Updated agent layout handling to support avatar changes and layout adjustments. - Refactored AgentNode to integrate new settings and avatar functionalities. - Improved SaveStatusIndicator positioning for better visibility. - Updated session service to handle avatar updates and session creation. - Enhanced agent slice to manage avatar state and session associations. --- service/app/api/v1/__init__.py | 2 + service/app/api/v1/agents.py | 74 +++- service/app/api/v1/avatar.py | 98 +++++ service/app/api/v1/sessions.py | 10 +- service/app/models/session_stats.py | 24 ++ service/app/models/sessions.py | 7 + service/app/repos/session_stats.py | 130 +++++- .../efa73edfbb15_add_session_daily_status.py | 34 ++ web/src/app/chat/SpatialWorkspace.tsx | 65 ++- web/src/app/chat/spatial/AgentNode.tsx | 373 +++++++++++------ .../app/chat/spatial/SaveStatusIndicator.tsx | 2 +- web/src/app/chat/spatial/types.ts | 28 ++ .../modals/SessionSettingsModal.tsx | 390 ++++++++++++++++++ web/src/service/sessionService.ts | 3 + web/src/store/slices/agentSlice.ts | 87 +++- web/src/types/agents.ts | 26 ++ 16 files changed, 1214 insertions(+), 139 deletions(-) create mode 100644 service/app/api/v1/avatar.py create mode 100644 service/migrations/versions/efa73edfbb15_add_session_daily_status.py create mode 100644 web/src/components/modals/SessionSettingsModal.tsx diff --git a/service/app/api/v1/__init__.py b/service/app/api/v1/__init__.py index b068e925..210cf7ca 100644 --- a/service/app/api/v1/__init__.py +++ b/service/app/api/v1/__init__.py @@ -3,6 +3,7 @@ from .agents import router as agents_router from .auth import router as auth_router +from .avatar import router as avatar_router from .checkin import router as checkin_router from .files import router as files_router from .folders import router as folders_router @@ -89,3 +90,4 @@ async def root() -> RootResponse: v1_router.include_router(folders_router, prefix="/folders") v1_router.include_router(knowledge_sets_router, prefix="/knowledge-sets") v1_router.include_router(marketplace_router, prefix="/marketplace") +v1_router.include_router(avatar_router, prefix="/avatar") diff --git a/service/app/api/v1/agents.py b/service/app/api/v1/agents.py index 1e4a3013..bc4be9c1 100644 --- a/service/app/api/v1/agents.py +++ b/service/app/api/v1/agents.py @@ -25,7 +25,7 @@ from app.infra.database import get_session from app.middleware.auth import get_current_user from app.models.agent import AgentCreate, AgentRead, AgentReadWithDetails, AgentScope, AgentUpdate -from app.models.session_stats import AgentStatsAggregated +from app.models.session_stats import AgentStatsAggregated, DailyStatsResponse, YesterdaySummary from app.repos 
import AgentRepository, KnowledgeSetRepository, ProviderRepository from app.repos.agent_marketplace import AgentMarketplaceRepository from app.repos.session import SessionRepository @@ -232,6 +232,78 @@ async def get_all_agent_stats( return await stats_repo.get_all_agent_stats_for_user(user) +@router.get("/stats/{agent_id}/daily", response_model=DailyStatsResponse) +async def get_agent_daily_stats( + agent_id: str, + days: int = 7, + user: str = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> DailyStatsResponse: + """ + Get daily message counts for an agent's sessions over the last N days. + + Useful for activity visualization charts. Returns counts for each day, + including days with zero activity. + + Args: + agent_id: Agent identifier (UUID string or builtin agent ID) + days: Number of days to include (default: 7) + user: Authenticated user ID (injected by dependency) + db: Database session (injected by dependency) + + Returns: + DailyStatsResponse: Daily message counts for the agent + """ + from app.models.sessions import builtin_agent_id_to_uuid + + # Resolve agent ID to UUID + if agent_id.startswith("builtin_"): + agent_uuid = builtin_agent_id_to_uuid(agent_id) + else: + try: + agent_uuid = UUID(agent_id) + except ValueError: + raise HTTPException(status_code=400, detail=f"Invalid agent ID format: '{agent_id}'") + + stats_repo = SessionStatsRepository(db) + return await stats_repo.get_daily_stats_for_agent(agent_uuid, user, days) + + +@router.get("/stats/{agent_id}/yesterday", response_model=YesterdaySummary) +async def get_agent_yesterday_summary( + agent_id: str, + user: str = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> YesterdaySummary: + """ + Get yesterday's activity summary for an agent's sessions. + + Returns the message count and optionally a preview of the last message. + Useful for displaying "You had X conversations yesterday" type summaries. + + Args: + agent_id: Agent identifier (UUID string or builtin agent ID) + user: Authenticated user ID (injected by dependency) + db: Database session (injected by dependency) + + Returns: + YesterdaySummary: Yesterday's activity summary + """ + from app.models.sessions import builtin_agent_id_to_uuid + + # Resolve agent ID to UUID + if agent_id.startswith("builtin_"): + agent_uuid = builtin_agent_id_to_uuid(agent_id) + else: + try: + agent_uuid = UUID(agent_id) + except ValueError: + raise HTTPException(status_code=400, detail=f"Invalid agent ID format: '{agent_id}'") + + stats_repo = SessionStatsRepository(db) + return await stats_repo.get_yesterday_summary_for_agent(agent_uuid, user) + + @router.get("/system/chat", response_model=AgentReadWithDetails) async def get_system_chat_agent( user: str = Depends(get_current_user), diff --git a/service/app/api/v1/avatar.py b/service/app/api/v1/avatar.py new file mode 100644 index 00000000..fc396e0c --- /dev/null +++ b/service/app/api/v1/avatar.py @@ -0,0 +1,98 @@ +""" +Avatar Proxy API. + +Proxies DiceBear avatar requests through the backend for better +accessibility in regions with slow access to api.dicebear.com. + +The avatars are SVG format and very small (~2-5KB), so caching +at CDN/browser level is usually sufficient. 
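+
+Illustrative round trip (path as mounted under the v1 router; the seed value is made up):
+    GET /xyzen/api/v1/avatar/avataaars/svg?seed=alice
+    -> fetches https://api.dicebear.com/9.x/avataaars/svg?seed=alice
+    -> returns image/svg+xml with "Cache-Control: public, max-age=31536000, immutable"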
+""" + +import httpx +from fastapi import APIRouter, HTTPException, Query +from fastapi.responses import Response + +router = APIRouter(tags=["avatar"]) + +# DiceBear API base URL +DICEBEAR_BASE = "https://api.dicebear.com/9.x" + +# Allowed styles to prevent abuse +ALLOWED_STYLES = { + "adventurer", + "avataaars", + "bottts", + "fun-emoji", + "lorelei", + "micah", + "miniavs", + "notionists", + "open-peeps", + "personas", + "pixel-art", + "shapes", + "thumbs", +} + +# HTTP client with connection pooling +_client: httpx.AsyncClient | None = None + + +async def get_client() -> httpx.AsyncClient: + global _client + if _client is None: + _client = httpx.AsyncClient(timeout=10.0) + return _client + + +@router.get("/{style}/svg") +async def proxy_avatar( + style: str, + seed: str = Query(..., description="Seed for generating the avatar"), +) -> Response: + """ + Proxy DiceBear avatar generation. + + This endpoint proxies requests to api.dicebear.com for better + accessibility in regions with slow international connectivity. + + Args: + style: Avatar style (e.g., avataaars, bottts, pixel-art) + seed: Seed string for deterministic avatar generation + + Returns: + SVG image response with appropriate caching headers + """ + if style not in ALLOWED_STYLES: + raise HTTPException( + status_code=400, + detail=f"Invalid style. Allowed: {', '.join(sorted(ALLOWED_STYLES))}", + ) + + # Build DiceBear URL + url = f"{DICEBEAR_BASE}/{style}/svg?seed={seed}" + + try: + client = await get_client() + response = await client.get(url) + response.raise_for_status() + + # Return SVG with long cache headers (avatars are deterministic) + return Response( + content=response.content, + media_type="image/svg+xml", + headers={ + "Cache-Control": "public, max-age=31536000, immutable", # 1 year + "Access-Control-Allow-Origin": "*", + }, + ) + except httpx.HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, + detail=f"DiceBear API error: {e.response.text}", + ) + except httpx.RequestError as e: + raise HTTPException( + status_code=502, + detail=f"Failed to fetch avatar: {str(e)}", + ) diff --git a/service/app/api/v1/sessions.py b/service/app/api/v1/sessions.py index f4ef323a..990621c2 100644 --- a/service/app/api/v1/sessions.py +++ b/service/app/api/v1/sessions.py @@ -175,7 +175,15 @@ async def update_session( Returns: SessionRead: The updated session """ + import logging + + logger = logging.getLogger(__name__) + logger.info( + f"[DEBUG] update_session called: session_id={session_id}, data={session_data.model_dump(exclude_unset=True)}" + ) try: - return await SessionService(db).update_session(session_id, session_data, user) + result = await SessionService(db).update_session(session_id, session_data, user) + logger.info(f"[DEBUG] update_session success: avatar={result.avatar}") + return result except ErrCodeError as e: raise handle_auth_error(e) diff --git a/service/app/models/session_stats.py b/service/app/models/session_stats.py index 92946606..d1aa8932 100644 --- a/service/app/models/session_stats.py +++ b/service/app/models/session_stats.py @@ -9,6 +9,7 @@ The schemas here are used for API responses only. 
""" +from datetime import date from uuid import UUID from pydantic import BaseModel @@ -46,3 +47,26 @@ class UserStatsAggregated(BaseModel): message_count: int = 0 input_tokens: int = 0 output_tokens: int = 0 + + +class DailyMessageCount(BaseModel): + """Message count for a specific day.""" + + date: date + message_count: int + + +class DailyStatsResponse(BaseModel): + """Daily activity stats for a session/agent (last N days).""" + + agent_id: UUID + daily_counts: list[DailyMessageCount] + + +class YesterdaySummary(BaseModel): + """Summary of yesterday's activity for a session.""" + + agent_id: UUID + message_count: int + last_message_content: str | None = None + summary: str | None = None # Optional AI-generated summary diff --git a/service/app/models/sessions.py b/service/app/models/sessions.py index 8daf47d9..efe816c3 100644 --- a/service/app/models/sessions.py +++ b/service/app/models/sessions.py @@ -79,6 +79,11 @@ class SessionBase(SQLModel): google_search_enabled: bool = Field( default=False, description="Enable built-in web search for supported models (e.g., Gemini)" ) + avatar: str | None = Field( + default=None, + max_length=500, + description="Session-specific avatar URL or DiceBear seed (e.g., 'dicebear:adventurer:seed123' or full URL)", + ) spatial_layout: dict[str, Any] | None = Field( default=None, sa_column=Column(JSON, nullable=True), @@ -107,6 +112,7 @@ class SessionCreate(SQLModel): model: str | None = None model_tier: ModelTier | None = None google_search_enabled: bool = False + avatar: str | None = None spatial_layout: dict[str, Any] | None = None @@ -131,4 +137,5 @@ class SessionUpdate(SQLModel): model: str | None = None model_tier: ModelTier | None = None google_search_enabled: bool | None = None + avatar: str | None = None spatial_layout: dict[str, Any] | None = None diff --git a/service/app/repos/session_stats.py b/service/app/repos/session_stats.py index 43678afa..ce3d1382 100644 --- a/service/app/repos/session_stats.py +++ b/service/app/repos/session_stats.py @@ -11,6 +11,7 @@ """ import logging +from datetime import date, datetime, timedelta, timezone from uuid import UUID from sqlalchemy import and_, func @@ -19,7 +20,13 @@ from app.models.consume import ConsumeRecord from app.models.message import Message -from app.models.session_stats import AgentStatsAggregated, SessionStatsRead +from app.models.session_stats import ( + AgentStatsAggregated, + DailyMessageCount, + DailyStatsResponse, + SessionStatsRead, + YesterdaySummary, +) from app.models.sessions import Session from app.models.topic import Topic @@ -215,3 +222,124 @@ async def get_all_agent_stats_for_user(self, user_id: str) -> dict[str, AgentSta ) return result + + async def get_daily_stats_for_agent(self, agent_id: UUID, user_id: str, days: int = 7) -> DailyStatsResponse: + """ + Get daily message counts for an agent's sessions over the last N days. + + Returns a list of (date, message_count) for activity visualization. 
+ """ + # Calculate date range + today = datetime.now(timezone.utc).date() + start_date = today - timedelta(days=days - 1) + start_datetime = datetime.combine(start_date, datetime.min.time()).replace(tzinfo=timezone.utc) + + # Query messages grouped by date + daily_stmt = ( + select( + func.date(Message.created_at).label("day"), + func.count(Message.id).label("message_count"), # type: ignore[arg-type] + ) + .select_from(Message) + .join(Topic, col(Message.topic_id) == col(Topic.id)) + .join(Session, col(Topic.session_id) == col(Session.id)) + .where( + and_( + col(Session.agent_id) == agent_id, + col(Session.user_id) == user_id, + col(Message.created_at) >= start_datetime, + ) + ) + .group_by(func.date(Message.created_at)) + .order_by(func.date(Message.created_at)) + ) + daily_result = await self.db.exec(daily_stmt) + daily_rows = list(daily_result.all()) + + # Build a map of date -> count + count_by_date: dict[date, int] = {} + for row in daily_rows: + day = row[0] # type: ignore + count = int(row[1]) # type: ignore + if isinstance(day, str): + day = datetime.strptime(day, "%Y-%m-%d").date() + count_by_date[day] = count + + # Fill in all days (including zeros) + daily_counts: list[DailyMessageCount] = [] + for i in range(days): + day = start_date + timedelta(days=i) + daily_counts.append( + DailyMessageCount( + date=day, + message_count=count_by_date.get(day, 0), + ) + ) + + return DailyStatsResponse( + agent_id=agent_id, + daily_counts=daily_counts, + ) + + async def get_yesterday_summary_for_agent(self, agent_id: UUID, user_id: str) -> YesterdaySummary: + """ + Get yesterday's activity summary for an agent's sessions. + + Returns the message count and optionally the last message content. + """ + # Calculate yesterday's date range + today = datetime.now(timezone.utc).date() + yesterday = today - timedelta(days=1) + yesterday_start = datetime.combine(yesterday, datetime.min.time()).replace(tzinfo=timezone.utc) + yesterday_end = datetime.combine(today, datetime.min.time()).replace(tzinfo=timezone.utc) + + # Count messages from yesterday + count_stmt = ( + select(func.count(Message.id)) # type: ignore[arg-type] + .select_from(Message) + .join(Topic, col(Message.topic_id) == col(Topic.id)) + .join(Session, col(Topic.session_id) == col(Session.id)) + .where( + and_( + col(Session.agent_id) == agent_id, + col(Session.user_id) == user_id, + col(Message.created_at) >= yesterday_start, + col(Message.created_at) < yesterday_end, + ) + ) + ) + count_result = await self.db.exec(count_stmt) + message_count = count_result.one_or_none() or 0 + + # Get the last assistant message from yesterday (if any) + last_message_stmt = ( + select(Message.content) + .select_from(Message) + .join(Topic, col(Message.topic_id) == col(Topic.id)) + .join(Session, col(Topic.session_id) == col(Session.id)) + .where( + and_( + col(Session.agent_id) == agent_id, + col(Session.user_id) == user_id, + col(Message.created_at) >= yesterday_start, + col(Message.created_at) < yesterday_end, + col(Message.role) == "assistant", + ) + ) + .order_by(col(Message.created_at).desc()) + .limit(1) + ) + last_result = await self.db.exec(last_message_stmt) + last_message = last_result.one_or_none() + + # Truncate long messages for preview + last_content = None + if last_message: + content = str(last_message) + last_content = content[:200] + "..." 
if len(content) > 200 else content + + return YesterdaySummary( + agent_id=agent_id, + message_count=int(message_count), + last_message_content=last_content, + ) diff --git a/service/migrations/versions/efa73edfbb15_add_session_daily_status.py b/service/migrations/versions/efa73edfbb15_add_session_daily_status.py new file mode 100644 index 00000000..63dbd45f --- /dev/null +++ b/service/migrations/versions/efa73edfbb15_add_session_daily_status.py @@ -0,0 +1,34 @@ +"""add session daily status + +Revision ID: efa73edfbb15 +Revises: e427ec7ce799 +Create Date: 2026-01-15 02:08:04.425794 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = "efa73edfbb15" +down_revision: Union[str, Sequence[str], None] = "e427ec7ce799" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("session", sa.Column("avatar", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("session", "avatar") + # ### end Alembic commands ### diff --git a/web/src/app/chat/SpatialWorkspace.tsx b/web/src/app/chat/SpatialWorkspace.tsx index d24357c1..64850804 100644 --- a/web/src/app/chat/SpatialWorkspace.tsx +++ b/web/src/app/chat/SpatialWorkspace.tsx @@ -10,6 +10,7 @@ import "@xyflow/react/dist/style.css"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import AddAgentModal from "@/components/modals/AddAgentModal"; +import EditAgentModal from "@/components/modals/EditAgentModal"; import { useXyzen } from "@/store"; import type { AgentSpatialLayout, @@ -39,6 +40,7 @@ import type { const agentToFlowNode = ( agent: AgentWithLayout, stats?: AgentStatsAggregated, + sessionId?: string, ): AgentFlowNode => { const statsDisplay: AgentStatsDisplay | undefined = stats ? 
{ @@ -54,6 +56,8 @@ const agentToFlowNode = ( type: "agent", position: agent.spatial_layout.position, data: { + agentId: agent.id, + sessionId: sessionId, name: agent.name, role: (agent.description?.split("\n")[0] || "Agent") as string, desc: agent.description || "", @@ -71,13 +75,21 @@ const agentToFlowNode = ( }; function InnerWorkspace() { - const { agents, fetchAgents, updateAgentLayout, agentStats } = useXyzen(); + const { + agents, + fetchAgents, + updateAgentLayout, + updateAgentAvatar, + agentStats, + sessionIdByAgentId, + } = useXyzen(); const [nodes, setNodes, onNodesChange] = useNodesState([]); const [edges, , onEdgesChange] = useEdgesState([]); const [focusedAgentId, setFocusedAgentId] = useState(null); const [saveStatus, setSaveStatus] = useState("idle"); const [isAddModalOpen, setAddModalOpen] = useState(false); + const [editingAgentId, setEditingAgentId] = useState(null); const containerRef = useRef(null); const [prevViewport, setPrevViewport] = useState<{ x: number; @@ -158,11 +170,12 @@ function InnerWorkspace() { if (agents.length > 0) { const flowNodes = agents.map((agent) => { const stats = agentStats[agent.id]; - return agentToFlowNode(agent, stats); + const sessionId = sessionIdByAgentId[agent.id]; + return agentToFlowNode(agent, stats, sessionId); }); setNodes(flowNodes); } - }, [agents, agentStats, setNodes]); + }, [agents, agentStats, sessionIdByAgentId, setNodes]); useEffect(() => { if (didInitialFitViewRef.current) return; @@ -252,6 +265,29 @@ function InnerWorkspace() { setPrevViewport(null); }, [prevViewport, setViewport]); + // Handle layout changes from AgentNode (e.g., resize) + const handleLayoutChange = useCallback( + (id: string, layout: AgentSpatialLayout) => { + scheduleSave(id, layout); + }, + [scheduleSave], + ); + + // Handle avatar changes from AgentNode + const handleAvatarChange = useCallback( + (id: string, avatarUrl: string) => { + updateAgentAvatar(id, avatarUrl).catch((err) => + console.error("Failed to update avatar:", err), + ); + }, + [updateAgentAvatar], + ); + + // Handle opening agent settings (EditAgentModal) + const handleOpenAgentSettings = useCallback((agentId: string) => { + setEditingAgentId(agentId); + }, []); + // Inject handleFocus into node data const nodeTypes = useMemo( () => ({ @@ -267,10 +303,20 @@ function InnerWorkspace() { data: { ...n.data, onFocus: handleFocus, + onLayoutChange: handleLayoutChange, + onAvatarChange: handleAvatarChange, + onOpenAgentSettings: handleOpenAgentSettings, isFocused: n.id === focusedAgentId, }, })); - }, [nodes, handleFocus, focusedAgentId]); + }, [ + nodes, + handleFocus, + handleLayoutChange, + handleAvatarChange, + handleOpenAgentSettings, + focusedAgentId, + ]); const handleNodeDragStop = useCallback( (_: unknown, draggedNode: AgentFlowNode) => { @@ -426,6 +472,17 @@ function InnerWorkspace() { isOpen={isAddModalOpen} onClose={() => setAddModalOpen(false)} /> + + {/* Edit Agent Modal */} + setEditingAgentId(null)} + agent={ + editingAgentId + ? (agents.find((a) => a.id === editingAgentId) ?? null) + : null + } + />
); } diff --git a/web/src/app/chat/spatial/AgentNode.tsx b/web/src/app/chat/spatial/AgentNode.tsx index 367e5b33..b8cafd39 100644 --- a/web/src/app/chat/spatial/AgentNode.tsx +++ b/web/src/app/chat/spatial/AgentNode.tsx @@ -1,4 +1,4 @@ -import { Modal } from "@/components/animate-ui/components/animate/modal"; +import SessionSettingsModal from "@/components/modals/SessionSettingsModal"; import { cn } from "@/lib/utils"; import { ChatBubbleLeftRightIcon, @@ -8,7 +8,12 @@ import { import { useReactFlow } from "@xyflow/react"; import { motion } from "framer-motion"; import { useState } from "react"; -import type { AgentFlowNodeProps, AgentStatsDisplay } from "./types"; +import type { + AgentFlowNodeProps, + AgentStatsDisplay, + DailyActivityData, + YesterdaySummaryData, +} from "./types"; // Helper to calc size // Base unit: 1x1 = 200x160. Gap = 16. @@ -37,69 +42,201 @@ const formatTokenCount = (count: number): string => { return count.toString(); }; +// Mini activity chart component for 7-day visualization +function ActivityChart({ + data, + className, +}: { + data: DailyActivityData[]; + className?: string; +}) { + const maxCount = Math.max(...data.map((d) => d.count), 1); + + return ( +
+ {data.map((day, i) => { + const heightPercent = (day.count / maxCount) * 100; + const isToday = i === data.length - 1; + return ( +
+ ); + })} +
+ ); +} + +// Yesterday summary bubble +function YesterdayBubble({ + summary, + className, +}: { + summary?: YesterdaySummaryData; + className?: string; +}) { + if (!summary) return null; + + const hasActivity = summary.messageCount > 0; + + return ( +
+ {hasActivity ? ( + <> + + 昨日聊了 {summary.messageCount} 条 + + {summary.lastMessagePreview && ( +

+ "{summary.lastMessagePreview}" +

+ )} + + ) : ( + 你昨天没有和我聊天哟 😢 + )} +
+ ); +} + // Stats display component with responsive layout function StatsDisplay({ stats, gridW, gridH, + dailyActivity, + yesterdaySummary, }: { stats?: AgentStatsDisplay; gridW: number; gridH: number; + dailyActivity?: DailyActivityData[]; + yesterdaySummary?: YesterdaySummaryData; }) { if (!stats) return null; const totalTokens = stats.inputTokens + stats.outputTokens; - const hasActivity = stats.messageCount > 0 || stats.topicCount > 0; const area = gridW * gridH; - // Compact 1x1: Only show message count as badge + // 1x1: Compact stats with all key metrics if (area === 1) { - if (!hasActivity) return null; return ( -
- - {stats.messageCount} -
- ); - } - - // 2x1 horizontal: Compact inline stats - if (gridW >= 2 && gridH === 1) { - return ( -
-
- - {stats.messageCount} -
-
- - {stats.topicCount} +
+ {/* Main stats row */} +
+
+ + {stats.messageCount} +
+
+ + {stats.topicCount} +
+ {/* Token mini bar */} {totalTokens > 0 && ( -
- {formatTokenCount(totalTokens)} tokens +
+
+
+
+
+
+ {formatTokenCount(totalTokens)} +
)}
); } - // 1x2 vertical: Stacked stats - if (gridW === 1 && gridH >= 2) { + // 2x1 horizontal or 1x2 vertical: Full stats with token bar (same layout as 2x2 but compact) + if (area === 2) { + const isHorizontal = gridW >= 2; return ( -
-
- - {stats.messageCount} messages -
-
- - {stats.topicCount} topics +
+ {/* Stats */} +
+
+ + + {stats.messageCount} + + msgs +
+
+ + + {stats.topicCount} + + topics +
+ {/* Token bar */} {totalTokens > 0 && ( -
- {formatTokenCount(totalTokens)} tokens +
+
+
+
+
+
+ {isHorizontal ? ( + formatTokenCount(totalTokens) + ) : ( + <> + ↓{formatTokenCount(stats.inputTokens)} + ↑{formatTokenCount(stats.outputTokens)} + + )} +
)}
@@ -110,7 +247,7 @@ function StatsDisplay({ return (
{/* Stats row */} -
+
@@ -127,6 +264,21 @@ function StatsDisplay({
+ {/* Yesterday summary bubble (for 2x2+) */} + {area >= 4 && ( + + )} + + {/* Daily activity chart (for 2x2+) */} + {area >= 4 && dailyActivity && dailyActivity.length > 0 && ( +
+
+ 7 Day Activity +
+ +
+ )} + {/* Token usage bar */} {totalTokens > 0 && (
@@ -156,78 +308,6 @@ function StatsDisplay({
)} - - {/* Activity visualization for larger sizes */} - {area >= 6 && hasActivity && ( -
- {Array.from({ length: Math.min(stats.topicCount + 2, 8) }).map( - (_, i) => ( -
- ), - )} -
- )} -
- ); -} - -function GridResizer({ - currentW = 1, - currentH = 1, - onResize, -}: { - currentW?: number; - currentH?: number; - onResize: (w: number, h: number) => void; -}) { - const [hover, setHover] = useState<{ w: number; h: number } | null>(null); - - return ( -
-
- Adjust the grid size of this agent widget. -
-
-
setHover(null)} - > - {Array.from({ length: 9 }).map((_, i) => { - const x = (i % 3) + 1; - const y = Math.floor(i / 3) + 1; - const isHovered = hover && x <= hover.w && y <= hover.h; - const isSelected = !hover && x <= currentW && y <= currentH; - - return ( -
setHover({ w: x, h: y })} - onClick={(e) => { - e.preventDefault(); - e.stopPropagation(); - onResize(x, y); - }} - /> - ); - })} -
-
-
- {hover ? `${hover.w} x ${hover.h}` : `${currentW} x ${currentH}`} -
); } @@ -241,27 +321,62 @@ export function AgentNode({ id, data, selected }: AgentFlowNodeProps) { const style = getSizeStyle(data.gridSize?.w, data.gridSize?.h, data.size); + const handleResize = (w: number, h: number) => { + const newSize = w * h > 3 ? "large" : w * h > 1 ? "medium" : "small"; + + // Update ReactFlow node data + updateNodeData(id, { + gridSize: { w, h }, + size: newSize, + }); + + // Notify parent to persist the layout change + if (data.onLayoutChange) { + data.onLayoutChange(id, { + position: data.position, + gridSize: { w, h }, + size: newSize, + }); + } + }; + + const handleAvatarChange = (avatarUrl: string) => { + console.log("[AgentNode] handleAvatarChange called:", { id, avatarUrl }); + // Update local node data + updateNodeData(id, { avatar: avatarUrl }); + + // Notify parent to persist avatar change + if (data.onAvatarChange) { + console.log("[AgentNode] Calling data.onAvatarChange"); + data.onAvatarChange(id, avatarUrl); + } else { + console.warn("[AgentNode] data.onAvatarChange is not defined!"); + } + }; + + const handleOpenAgentSettings = () => { + setIsSettingsOpen(false); + if (data.onOpenAgentSettings && data.agentId) { + data.onOpenAgentSettings(data.agentId); + } + }; + return ( <> - setIsSettingsOpen(false)} - title="Widget Settings" - maxWidth="max-w-xs" - > - { - updateNodeData(id, { - gridSize: { w, h }, - size: w * h > 3 ? "large" : w * h > 1 ? "medium" : "small", - }); - // Optional: Close modal after selection if desired, or keep open - // setIsSettingsOpen(false); - }} - /> - + sessionId={data.sessionId || id} + agentId={data.agentId || id} + agentName={data.name} + currentAvatar={data.avatar} + currentGridSize={data.gridSize || { w: currentW, h: currentH }} + onAvatarChange={handleAvatarChange} + onGridSizeChange={handleResize} + onOpenAgentSettings={ + data.onOpenAgentSettings ? handleOpenAgentSettings : undefined + } + />
-
+
{data.name}
@@ -345,6 +460,8 @@ export function AgentNode({ id, data, selected }: AgentFlowNodeProps) { stats={data.stats} gridW={currentW} gridH={currentH} + dailyActivity={data.dailyActivity} + yesterdaySummary={data.yesterdaySummary} />
diff --git a/web/src/app/chat/spatial/SaveStatusIndicator.tsx b/web/src/app/chat/spatial/SaveStatusIndicator.tsx index 18b9e15b..b0e83136 100644 --- a/web/src/app/chat/spatial/SaveStatusIndicator.tsx +++ b/web/src/app/chat/spatial/SaveStatusIndicator.tsx @@ -28,7 +28,7 @@ export function SaveStatusIndicator({ animate={{ opacity: 1, y: 0 }} exit={{ opacity: 0, y: -10 }} transition={{ duration: 0.2 }} - className="fixed top-4 right-4 z-50" + className="fixed top-20 right-4 z-50" > {status === "saving" && (
diff --git a/web/src/app/chat/spatial/types.ts b/web/src/app/chat/spatial/types.ts index 88a292a8..b54e1fb0 100644 --- a/web/src/app/chat/spatial/types.ts +++ b/web/src/app/chat/spatial/types.ts @@ -1,3 +1,4 @@ +import type { AgentSpatialLayout } from "@/types/agents"; import type { Node, NodeProps } from "@xyflow/react"; export type XYPosition = { x: number; y: number }; @@ -14,6 +15,22 @@ export interface AgentStatsDisplay { outputTokens: number; } +/** + * Daily message count for activity chart. + */ +export interface DailyActivityData { + date: string; + count: number; +} + +/** + * Yesterday's summary data for agent. + */ +export interface YesterdaySummaryData { + messageCount: number; + lastMessagePreview?: string | null; +} + /** * Persistable agent widget data (no functions). * @@ -21,6 +38,10 @@ export interface AgentStatsDisplay { * so it can be persisted/serialized without the full Node shape. */ export interface AgentData { + /** Agent ID (used for API calls) */ + agentId: string; + /** Session ID (used for Session API calls) */ + sessionId?: string; name: string; role: string; desc: string; @@ -31,11 +52,18 @@ export interface AgentData { position: XYPosition; /** Stats for display visualization */ stats?: AgentStatsDisplay; + /** Daily activity for chart (last 7 days) */ + dailyActivity?: DailyActivityData[]; + /** Yesterday's summary */ + yesterdaySummary?: YesterdaySummaryData; } /** Runtime-only fields injected by the workspace. */ export interface AgentNodeRuntimeData { onFocus: (id: string) => void; + onLayoutChange?: (id: string, layout: AgentSpatialLayout) => void; + onAvatarChange?: (id: string, avatarUrl: string) => void; + onOpenAgentSettings?: (agentId: string) => void; isFocused?: boolean; } diff --git a/web/src/components/modals/SessionSettingsModal.tsx b/web/src/components/modals/SessionSettingsModal.tsx new file mode 100644 index 00000000..3da90cb5 --- /dev/null +++ b/web/src/components/modals/SessionSettingsModal.tsx @@ -0,0 +1,390 @@ +import { Modal } from "@/components/animate-ui/components/animate/modal"; +import { cn } from "@/lib/utils"; +import { useXyzen } from "@/store"; +import { + ArrowsPointingOutIcon, + CheckIcon, + Cog6ToothIcon, + SparklesIcon, +} from "@heroicons/react/24/outline"; +import React, { useCallback, useMemo, useState } from "react"; +import { useTranslation } from "react-i18next"; + +// Preset DiceBear styles and seeds for avatar selection +const DICEBEAR_STYLES = [ + "adventurer", + "avataaars", + "bottts", + "fun-emoji", + "lorelei", + "micah", + "miniavs", + "notionists", + "open-peeps", + "personas", + "pixel-art", + "shapes", + "thumbs", +] as const; + +/** + * Build avatar URL - uses backend proxy if available for better China access. + * Falls back to direct DiceBear API if backend URL is not configured. 
+ */ +const buildAvatarUrl = ( + style: string, + seed: string, + backendUrl?: string, +): string => { + // Use backend proxy for better accessibility in China + if (backendUrl) { + return `${backendUrl}/xyzen/api/v1/avatar/${style}/svg?seed=${encodeURIComponent(seed)}`; + } + // Fallback to direct DiceBear API + return `https://api.dicebear.com/9.x/${style}/svg?seed=${seed}`; +}; + +// Generate a set of preset avatars using different styles and random seeds +const generatePresetAvatars = (backendUrl?: string) => { + const avatars: { url: string; seed: string; style: string }[] = []; + + // Generate 3 avatars for each style + DICEBEAR_STYLES.forEach((style) => { + for (let i = 0; i < 3; i++) { + const seed = `${style}_${i}_preset`; + avatars.push({ + url: buildAvatarUrl(style, seed, backendUrl), + seed, + style, + }); + } + }); + + return avatars; +}; + +interface GridResizerProps { + currentW?: number; + currentH?: number; + onResize: (w: number, h: number) => void; +} + +function GridResizer({ + currentW = 1, + currentH = 1, + onResize, +}: GridResizerProps) { + const [hover, setHover] = useState<{ w: number; h: number } | null>(null); + + return ( +
+
+
setHover(null)} + > + {Array.from({ length: 9 }).map((_, i) => { + const x = (i % 3) + 1; + const y = Math.floor(i / 3) + 1; + const isHovered = hover && x <= hover.w && y <= hover.h; + const isSelected = !hover && x <= currentW && y <= currentH; + + return ( +
setHover({ w: x, h: y })} + onClick={(e) => { + e.preventDefault(); + e.stopPropagation(); + onResize(x, y); + }} + /> + ); + })} +
+
+
+ {hover ? `${hover.w} × ${hover.h}` : `${currentW} × ${currentH}`} +
+
+ ); +} + +interface AvatarSelectorProps { + currentAvatar?: string; + onSelect: (avatarUrl: string) => void; + backendUrl?: string; +} + +function AvatarSelector({ + currentAvatar, + onSelect, + backendUrl, +}: AvatarSelectorProps) { + const [selectedStyle, setSelectedStyle] = + useState<(typeof DICEBEAR_STYLES)[number]>("avataaars"); + const [customSeed, setCustomSeed] = useState(""); + + // Generate preset avatars with backend URL + const presetAvatars = useMemo( + () => generatePresetAvatars(backendUrl), + [backendUrl], + ); + + // Filter avatars by selected style + const filteredAvatars = presetAvatars.filter( + (a) => a.style === selectedStyle, + ); + + // Generate random avatar + const generateRandom = useCallback(() => { + const seed = Math.random().toString(36).slice(2, 10); + const url = buildAvatarUrl(selectedStyle, seed, backendUrl); + console.log("[AvatarSelector] generateRandom:", { + style: selectedStyle, + seed, + url, + }); + onSelect(url); + }, [selectedStyle, onSelect, backendUrl]); + + // Generate from custom seed + const generateFromSeed = useCallback(() => { + if (!customSeed.trim()) return; + const url = buildAvatarUrl(selectedStyle, customSeed.trim(), backendUrl); + console.log("[AvatarSelector] generateFromSeed:", { + style: selectedStyle, + seed: customSeed, + url, + }); + onSelect(url); + }, [selectedStyle, customSeed, onSelect, backendUrl]); + + // Handle preset avatar selection + const handlePresetSelect = useCallback( + (avatarUrl: string) => { + console.log("[AvatarSelector] handlePresetSelect:", { avatarUrl }); + onSelect(avatarUrl); + }, + [onSelect], + ); + + return ( +
+ {/* Current Avatar Preview */} +
+
+ Current avatar + {currentAvatar && ( +
+ +
+ )} +
+
+ + {/* Style Selector */} +
+ +
+ {DICEBEAR_STYLES.map((style) => ( + + ))} +
+
+ + {/* Preset Avatars Grid */} +
+ +
+ {filteredAvatars.map((avatar, i) => ( + + ))} +
+
+ + {/* Random & Custom Seed */} +
+ +
+ setCustomSeed(e.target.value)} + placeholder="Custom seed..." + className="flex-1 px-2 py-1.5 text-xs border border-neutral-200 dark:border-neutral-700 rounded-md bg-white dark:bg-neutral-800 text-neutral-900 dark:text-neutral-100 placeholder-neutral-400" + /> + +
+
+
+ ); +} + +interface SessionSettingsModalProps { + isOpen: boolean; + onClose: () => void; + sessionId: string; + agentId: string; + agentName: string; + currentAvatar?: string; + currentGridSize?: { w: number; h: number }; + onAvatarChange: (avatarUrl: string) => void; + onGridSizeChange: (w: number, h: number) => void; + onOpenAgentSettings?: () => void; +} + +const SessionSettingsModal: React.FC = ({ + isOpen, + onClose, + agentName, + currentAvatar, + currentGridSize, + onAvatarChange, + onGridSizeChange, + onOpenAgentSettings, +}) => { + const { t } = useTranslation(); + const backendUrl = useXyzen((state) => state.backendUrl); + const [activeSection, setActiveSection] = useState<"avatar" | "size">( + "avatar", + ); + + return ( + +
+ {/* Section Tabs */} +
+ + +
+ + {/* Content */} + {activeSection === "avatar" && ( + + )} + + {activeSection === "size" && ( +
+

+ Adjust the widget size in the spatial workspace +

+ +
+ )} + + {/* Open Agent Settings */} + {onOpenAgentSettings && ( +
+ +
+ )} +
+
+ ); +}; + +export default SessionSettingsModal; diff --git a/web/src/service/sessionService.ts b/web/src/service/sessionService.ts index db593436..eee528f6 100644 --- a/web/src/service/sessionService.ts +++ b/web/src/service/sessionService.ts @@ -11,6 +11,7 @@ export interface SessionCreate { model?: string; model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; + avatar?: string; spatial_layout?: AgentSpatialLayout; } @@ -22,6 +23,7 @@ export interface SessionUpdate { model?: string; model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; + avatar?: string; spatial_layout?: AgentSpatialLayout; } @@ -36,6 +38,7 @@ export interface SessionRead { model?: string; model_tier?: "ultra" | "pro" | "standard" | "lite"; google_search_enabled?: boolean; + avatar?: string; spatial_layout?: AgentSpatialLayout; created_at: string; updated_at: string; diff --git a/web/src/store/slices/agentSlice.ts b/web/src/store/slices/agentSlice.ts index 732ebf84..fd2228d9 100644 --- a/web/src/store/slices/agentSlice.ts +++ b/web/src/store/slices/agentSlice.ts @@ -42,6 +42,7 @@ export interface AgentSlice { agentId: string, layout: AgentSpatialLayout, ) => Promise; + updateAgentAvatar: (agentId: string, avatarUrl: string) => Promise; updateAgentProvider: ( agentId: string, providerId: string | null, @@ -135,10 +136,11 @@ export const createAgentSlice: StateCreator< const rawAgents: Agent[] = await response.json(); - // Fetch sessions for each agent to get spatial_layout - // Build session mapping and extract layouts + // Fetch sessions for each agent to get spatial_layout and avatar + // Build session mapping and extract layouts/avatars const sessionMap: Record = {}; const layoutMap: Record = {}; + const avatarMap: Record = {}; await Promise.all( rawAgents.map(async (agent) => { @@ -148,6 +150,9 @@ export const createAgentSlice: StateCreator< if (session.spatial_layout) { layoutMap[agent.id] = session.spatial_layout; } + if (session.avatar) { + avatarMap[agent.id] = session.avatar; + } } catch { // Session doesn't exist yet - will be created when user starts chat console.debug(`No session found for agent ${agent.id}`); @@ -155,11 +160,12 @@ export const createAgentSlice: StateCreator< }), ); - // Enrich agents with layout from session or default + // Enrich agents with layout and avatar from session or default const agents: AgentWithLayout[] = rawAgents.map((agent, index) => ({ ...agent, spatial_layout: layoutMap[agent.id] ?? defaultSpatialLayoutForIndex(index), + avatar: avatarMap[agent.id] ?? 
agent.avatar, })); set({ @@ -374,6 +380,81 @@ export const createAgentSlice: StateCreator< } }, + updateAgentAvatar: async (agentId, avatarUrl) => { + console.log("[agentSlice] updateAgentAvatar called:", { + agentId, + avatarUrl, + }); + try { + // Get the session ID for this agent + let sessionId = get().sessionIdByAgentId[agentId]; + console.log( + "[agentSlice] Current sessionIdByAgentId:", + get().sessionIdByAgentId, + ); + console.log("[agentSlice] Found sessionId:", sessionId); + + if (!sessionId) { + // Try to fetch the session if not cached + try { + console.log("[agentSlice] Fetching session for agent:", agentId); + const session = await sessionService.getSessionByAgent(agentId); + sessionId = session.id; + console.log("[agentSlice] Got session from API:", session.id); + // Cache it + set((state) => { + state.sessionIdByAgentId[agentId] = sessionId; + }); + } catch (fetchError) { + // Session doesn't exist - create one first + console.warn( + `[agentSlice] No session found for agent ${agentId}, creating one...`, + fetchError, + ); + const agent = get().agents.find((a) => a.id === agentId); + const newSession = await sessionService.createSession({ + name: agent?.name ?? "Agent Session", + agent_id: agentId, + avatar: avatarUrl, + }); + console.log("[agentSlice] Created new session:", newSession.id); + sessionId = newSession.id; + set((state) => { + state.sessionIdByAgentId[agentId] = sessionId; + }); + + // Update local state + set((state) => { + const agentData = state.agents.find((a) => a.id === agentId); + if (agentData) { + agentData.avatar = avatarUrl; + } + }); + return; + } + } + + // Update the session's avatar via Session API + console.log("[agentSlice] Updating session avatar:", { + sessionId, + avatarUrl, + }); + await sessionService.updateSession(sessionId, { avatar: avatarUrl }); + console.log("[agentSlice] Session avatar updated successfully"); + + // Update local state optimistically + set((state) => { + const agent = state.agents.find((a) => a.id === agentId); + if (agent) { + agent.avatar = avatarUrl; + } + }); + } catch (error) { + console.error("[agentSlice] Failed to update agent avatar:", error); + throw error; + } + }, + updateAgentProvider: async (agentId, providerId) => { try { const response = await fetch( diff --git a/web/src/types/agents.ts b/web/src/types/agents.ts index dfe54bc4..1ccc29e7 100644 --- a/web/src/types/agents.ts +++ b/web/src/types/agents.ts @@ -24,6 +24,32 @@ export interface AgentStatsAggregated { output_tokens: number; } +/** + * Daily message count for activity visualization. + */ +export interface DailyMessageCount { + date: string; // ISO date string (YYYY-MM-DD) + message_count: number; +} + +/** + * Daily activity stats for an agent (last N days). + */ +export interface DailyStatsResponse { + agent_id: string; + daily_counts: DailyMessageCount[]; +} + +/** + * Yesterday's activity summary for a session/agent. + */ +export interface YesterdaySummary { + agent_id: string; + message_count: number; + last_message_content?: string | null; + summary?: string | null; +} + /** * Calculate the visual scale multiplier based on message count. 
* Uses a logarithmic curve for diminishing returns: From 456a387dc7babf58d80a1b5f6b675a09f7514b6f Mon Sep 17 00:00:00 2001 From: Harvey Date: Thu, 15 Jan 2026 02:34:10 +0800 Subject: [PATCH 08/11] feat: Remove debug logging and enhance session update flow --- service/app/api/v1/sessions.py | 10 +-- web/src/app/chat/SpatialWorkspace.tsx | 40 +++++++++- web/src/app/chat/spatial/AgentNode.tsx | 4 - .../modals/SessionSettingsModal.tsx | 11 --- web/src/store/slices/agentSlice.ts | 78 ++++++++++++++----- 5 files changed, 97 insertions(+), 46 deletions(-) diff --git a/service/app/api/v1/sessions.py b/service/app/api/v1/sessions.py index 990621c2..f4ef323a 100644 --- a/service/app/api/v1/sessions.py +++ b/service/app/api/v1/sessions.py @@ -175,15 +175,7 @@ async def update_session( Returns: SessionRead: The updated session """ - import logging - - logger = logging.getLogger(__name__) - logger.info( - f"[DEBUG] update_session called: session_id={session_id}, data={session_data.model_dump(exclude_unset=True)}" - ) try: - result = await SessionService(db).update_session(session_id, session_data, user) - logger.info(f"[DEBUG] update_session success: avatar={result.avatar}") - return result + return await SessionService(db).update_session(session_id, session_data, user) except ErrCodeError as e: raise handle_auth_error(e) diff --git a/web/src/app/chat/SpatialWorkspace.tsx b/web/src/app/chat/SpatialWorkspace.tsx index 64850804..bec087b4 100644 --- a/web/src/app/chat/SpatialWorkspace.tsx +++ b/web/src/app/chat/SpatialWorkspace.tsx @@ -29,7 +29,9 @@ import type { AgentData, AgentFlowNode, AgentStatsDisplay, + DailyActivityData, FlowAgentNodeData, + YesterdaySummaryData, } from "./spatial/types"; /** @@ -41,6 +43,8 @@ const agentToFlowNode = ( agent: AgentWithLayout, stats?: AgentStatsAggregated, sessionId?: string, + dailyActivity?: DailyActivityData[], + yesterdaySummary?: YesterdaySummaryData, ): AgentFlowNode => { const statsDisplay: AgentStatsDisplay | undefined = stats ? { @@ -69,6 +73,8 @@ const agentToFlowNode = ( gridSize: agent.spatial_layout.gridSize, position: agent.spatial_layout.position, stats: statsDisplay, + dailyActivity, + yesterdaySummary, onFocus: () => {}, } as FlowAgentNodeData, }; @@ -82,6 +88,8 @@ function InnerWorkspace() { updateAgentAvatar, agentStats, sessionIdByAgentId, + dailyActivity, + yesterdaySummary, } = useXyzen(); const [nodes, setNodes, onNodesChange] = useNodesState([]); @@ -171,11 +179,39 @@ function InnerWorkspace() { const flowNodes = agents.map((agent) => { const stats = agentStats[agent.id]; const sessionId = sessionIdByAgentId[agent.id]; - return agentToFlowNode(agent, stats, sessionId); + // Convert daily activity to the format expected by AgentNode + const agentDailyActivity = dailyActivity[agent.id]?.daily_counts?.map( + (d) => ({ + date: d.date, + count: d.message_count, + }), + ); + // Convert yesterday summary + const agentYesterdaySummary = yesterdaySummary[agent.id] + ? 
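+          // map the snake_case API fields onto the camelCase shape AgentNode renders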
{ + messageCount: yesterdaySummary[agent.id].message_count, + lastMessagePreview: + yesterdaySummary[agent.id].last_message_content, + } + : undefined; + return agentToFlowNode( + agent, + stats, + sessionId, + agentDailyActivity, + agentYesterdaySummary, + ); }); setNodes(flowNodes); } - }, [agents, agentStats, sessionIdByAgentId, setNodes]); + }, [ + agents, + agentStats, + sessionIdByAgentId, + dailyActivity, + yesterdaySummary, + setNodes, + ]); useEffect(() => { if (didInitialFitViewRef.current) return; diff --git a/web/src/app/chat/spatial/AgentNode.tsx b/web/src/app/chat/spatial/AgentNode.tsx index b8cafd39..519f07d9 100644 --- a/web/src/app/chat/spatial/AgentNode.tsx +++ b/web/src/app/chat/spatial/AgentNode.tsx @@ -341,16 +341,12 @@ export function AgentNode({ id, data, selected }: AgentFlowNodeProps) { }; const handleAvatarChange = (avatarUrl: string) => { - console.log("[AgentNode] handleAvatarChange called:", { id, avatarUrl }); // Update local node data updateNodeData(id, { avatar: avatarUrl }); // Notify parent to persist avatar change if (data.onAvatarChange) { - console.log("[AgentNode] Calling data.onAvatarChange"); data.onAvatarChange(id, avatarUrl); - } else { - console.warn("[AgentNode] data.onAvatarChange is not defined!"); } }; diff --git a/web/src/components/modals/SessionSettingsModal.tsx b/web/src/components/modals/SessionSettingsModal.tsx index 3da90cb5..bc1869b1 100644 --- a/web/src/components/modals/SessionSettingsModal.tsx +++ b/web/src/components/modals/SessionSettingsModal.tsx @@ -146,11 +146,6 @@ function AvatarSelector({ const generateRandom = useCallback(() => { const seed = Math.random().toString(36).slice(2, 10); const url = buildAvatarUrl(selectedStyle, seed, backendUrl); - console.log("[AvatarSelector] generateRandom:", { - style: selectedStyle, - seed, - url, - }); onSelect(url); }, [selectedStyle, onSelect, backendUrl]); @@ -158,18 +153,12 @@ function AvatarSelector({ const generateFromSeed = useCallback(() => { if (!customSeed.trim()) return; const url = buildAvatarUrl(selectedStyle, customSeed.trim(), backendUrl); - console.log("[AvatarSelector] generateFromSeed:", { - style: selectedStyle, - seed: customSeed, - url, - }); onSelect(url); }, [selectedStyle, customSeed, onSelect, backendUrl]); // Handle preset avatar selection const handlePresetSelect = useCallback( (avatarUrl: string) => { - console.log("[AvatarSelector] handlePresetSelect:", { avatarUrl }); onSelect(avatarUrl); }, [onSelect], diff --git a/web/src/store/slices/agentSlice.ts b/web/src/store/slices/agentSlice.ts index fd2228d9..9348913a 100644 --- a/web/src/store/slices/agentSlice.ts +++ b/web/src/store/slices/agentSlice.ts @@ -5,7 +5,9 @@ import type { AgentSpatialLayout, AgentStatsAggregated, AgentWithLayout, + DailyStatsResponse, SystemAgentTemplate, + YesterdaySummary, } from "@/types/agents"; import type { StateCreator } from "zustand"; import type { XyzenState } from "../types"; @@ -22,6 +24,11 @@ export interface AgentSlice { agentStats: Record; agentStatsLoading: boolean; + // Daily activity data for charts (last 7 days) + dailyActivity: Record; + // Yesterday summary for each agent + yesterdaySummary: Record; + // System agent templates systemAgentTemplates: SystemAgentTemplate[]; templatesLoading: boolean; @@ -29,6 +36,7 @@ export interface AgentSlice { fetchAgents: () => Promise; fetchAgentStats: () => Promise; + fetchDailyActivity: () => Promise; incrementLocalAgentMessageCount: (agentId: string) => void; isCreatingAgent: boolean; @@ -94,6 +102,10 @@ export const 
createAgentSlice: StateCreator< agentStats: {}, agentStatsLoading: false, + // Daily activity and yesterday summary + dailyActivity: {}, + yesterdaySummary: {}, + // System agent templates state systemAgentTemplates: [], templatesLoading: false, @@ -199,6 +211,9 @@ export const createAgentSlice: StateCreator< const stats: Record = await response.json(); set({ agentStats: stats, agentStatsLoading: false }); + + // Also fetch daily activity for visualization + get().fetchDailyActivity(); } catch (error) { console.error("Failed to fetch agent stats:", error); set({ agentStatsLoading: false }); @@ -206,6 +221,47 @@ export const createAgentSlice: StateCreator< } }, + fetchDailyActivity: async () => { + const agents = get().agents; + if (agents.length === 0) return; + + const backendUrl = get().backendUrl; + const dailyActivity: Record = {}; + const yesterdaySummary: Record = {}; + + // Fetch daily stats and yesterday summary for each agent in parallel + await Promise.all( + agents.map(async (agent) => { + try { + // Fetch daily stats (last 7 days) + const dailyResponse = await fetch( + `${backendUrl}/xyzen/api/v1/agents/stats/${agent.id}/daily`, + { headers: createAuthHeaders() }, + ); + if (dailyResponse.ok) { + dailyActivity[agent.id] = await dailyResponse.json(); + } + + // Fetch yesterday summary + const yesterdayResponse = await fetch( + `${backendUrl}/xyzen/api/v1/agents/stats/${agent.id}/yesterday`, + { headers: createAuthHeaders() }, + ); + if (yesterdayResponse.ok) { + yesterdaySummary[agent.id] = await yesterdayResponse.json(); + } + } catch (error) { + console.debug( + `Failed to fetch activity for agent ${agent.id}:`, + error, + ); + } + }), + ); + + set({ dailyActivity, yesterdaySummary }); + }, + /** * Optimistically increment the local message count for an agent. * Used for immediate UI feedback when a message is sent. 
@@ -381,26 +437,15 @@ export const createAgentSlice: StateCreator< }, updateAgentAvatar: async (agentId, avatarUrl) => { - console.log("[agentSlice] updateAgentAvatar called:", { - agentId, - avatarUrl, - }); try { // Get the session ID for this agent let sessionId = get().sessionIdByAgentId[agentId]; - console.log( - "[agentSlice] Current sessionIdByAgentId:", - get().sessionIdByAgentId, - ); - console.log("[agentSlice] Found sessionId:", sessionId); if (!sessionId) { // Try to fetch the session if not cached try { - console.log("[agentSlice] Fetching session for agent:", agentId); const session = await sessionService.getSessionByAgent(agentId); sessionId = session.id; - console.log("[agentSlice] Got session from API:", session.id); // Cache it set((state) => { state.sessionIdByAgentId[agentId] = sessionId; @@ -408,8 +453,7 @@ export const createAgentSlice: StateCreator< } catch (fetchError) { // Session doesn't exist - create one first console.warn( - `[agentSlice] No session found for agent ${agentId}, creating one...`, - fetchError, + `No session found for agent ${agentId}, creating one...`, ); const agent = get().agents.find((a) => a.id === agentId); const newSession = await sessionService.createSession({ @@ -417,7 +461,6 @@ export const createAgentSlice: StateCreator< agent_id: agentId, avatar: avatarUrl, }); - console.log("[agentSlice] Created new session:", newSession.id); sessionId = newSession.id; set((state) => { state.sessionIdByAgentId[agentId] = sessionId; @@ -435,12 +478,7 @@ export const createAgentSlice: StateCreator< } // Update the session's avatar via Session API - console.log("[agentSlice] Updating session avatar:", { - sessionId, - avatarUrl, - }); await sessionService.updateSession(sessionId, { avatar: avatarUrl }); - console.log("[agentSlice] Session avatar updated successfully"); // Update local state optimistically set((state) => { @@ -450,7 +488,7 @@ export const createAgentSlice: StateCreator< } }); } catch (error) { - console.error("[agentSlice] Failed to update agent avatar:", error); + console.error("Failed to update agent avatar:", error); throw error; } }, From cfd3399fbdf3f4964a91015a2a53b59003a39065 Mon Sep 17 00:00:00 2001 From: Harvey Date: Thu, 15 Jan 2026 03:42:04 +0800 Subject: [PATCH 09/11] feat: Enhance chat functionality with spatial chat frosted glass effect and agent channel activation --- README.md | 2 +- web/index.css | 56 ++++++ web/src/app/chat/SpatialWorkspace.tsx | 29 ++-- web/src/app/chat/spatial/AddAgentButton.tsx | 40 ++--- web/src/app/chat/spatial/FocusedView.tsx | 91 ++++------ web/src/components/layouts/XyzenChat.tsx | 7 +- .../layouts/components/ChatBubble.tsx | 52 ++++-- .../layouts/components/WelcomeMessage.tsx | 118 ++++++++----- web/src/store/slices/agentSlice.ts | 27 ++- web/src/store/slices/chatSlice.ts | 160 +++++++++++++++--- 10 files changed, 410 insertions(+), 172 deletions(-) diff --git a/README.md b/README.md index ea1c56d3..698aa3e4 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # Xyzen -AI Laboratory Server +Your Next Agent Capital! [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) [![Python 3.13](https://img.shields.io/badge/python-3.13-blue.svg)](https://www.python.org/downloads/release/python-3130/) diff --git a/web/index.css b/web/index.css index 2ce633be..65de1d0b 100644 --- a/web/index.css +++ b/web/index.css @@ -242,3 +242,59 @@ /* 2. 
光条移动到右侧,形成扫光动画 */ transform: skewX(-45deg) translateX(400%); } + +/* Spatial Chat Frosted Glass Effect */ +.spatial-chat-frosted { + /* Make inner XyzenChat backgrounds transparent */ + & > div { + background: transparent !important; + } + + /* Header area - frosted glass */ + & .sm\:border-y, + & .border-b { + background: rgba(255, 255, 255, 0.3) !important; + border-color: rgba(255, 255, 255, 0.2) !important; + backdrop-filter: blur(8px); + } + + /* Messages container - transparent */ + & .bg-neutral-50 { + background: transparent !important; + } + + /* Input area - frosted glass */ + & .bg-white:has(textarea), + & .bg-white:has(.tiptap) { + background: rgba(255, 255, 255, 0.4) !important; + backdrop-filter: blur(8px); + } + + /* Toolbar area */ + & .border-t { + border-color: rgba(255, 255, 255, 0.2) !important; + } +} + +/* Dark mode adjustments */ +.dark .spatial-chat-frosted { + & .sm\:border-y, + & .border-b { + background: rgba(0, 0, 0, 0.3) !important; + border-color: rgba(255, 255, 255, 0.1) !important; + } + + & .dark\:bg-black { + background: transparent !important; + } + + & .bg-white:has(textarea), + & .bg-white:has(.tiptap), + & .dark\:bg-neutral-900 { + background: rgba(0, 0, 0, 0.4) !important; + } + + & .border-t { + border-color: rgba(255, 255, 255, 0.1) !important; + } +} diff --git a/web/src/app/chat/SpatialWorkspace.tsx b/web/src/app/chat/SpatialWorkspace.tsx index bec087b4..494897f4 100644 --- a/web/src/app/chat/SpatialWorkspace.tsx +++ b/web/src/app/chat/SpatialWorkspace.tsx @@ -265,29 +265,24 @@ function InnerWorkspace() { const node = getNode(id); if (!node) return; - const measuredH = node.measured?.height; - const gridSize = (node.data as FlowAgentNodeData | undefined)?.gridSize; - const fallbackH = - gridSize?.w && gridSize?.h - ? gridSize.h * 160 + (gridSize.h - 1) * 16 - : 220; - - const nodeH = measuredH ?? fallbackH; - const centerY = node.position.y + nodeH / 2; - - // Focus layout: keep a consistent left padding regardless of node size. + // Focus layout: keep a consistent left padding and top padding regardless of node size. const targetZoom = 1.35; const rect = containerRef.current?.getBoundingClientRect(); const containerW = rect?.width ?? window.innerWidth; const containerH = rect?.height ?? window.innerHeight; + // Fixed left padding (responsive but clamped) const leftPadding = Math.max(24, Math.min(64, containerW * 0.08)); const screenX = leftPadding; - const screenY = containerH * 0.25; - // Align the node's left edge to screenX. + // Fixed top padding: consistent distance from top regardless of node size + // Use similar logic to leftPadding for responsive but clamped value + const topPadding = Math.max(24, Math.min(80, containerH * 0.06)); + const screenY = topPadding; + + // Align the node's left edge to screenX and top edge to screenY. const x = -node.position.x * targetZoom + screenX; - const y = -centerY * targetZoom + screenY; + const y = -node.position.y * targetZoom + screenY; setViewport({ x, y, zoom: targetZoom }, { duration: 900 }); }, @@ -489,8 +484,10 @@ function InnerWorkspace() { {/* Save Status Indicator */} - {/* Add Agent Button */} - setAddModalOpen(true)} /> + {/* Add Agent Button - positioned at bottom right, below focus overlay */} +
+ setAddModalOpen(true)} /> +
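+        {/* z-10 keeps the button clickable above the canvas while staying below the focus overlay */}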
{focusedAgent && ( diff --git a/web/src/app/chat/spatial/AddAgentButton.tsx b/web/src/app/chat/spatial/AddAgentButton.tsx index 327251b2..ae33a1f8 100644 --- a/web/src/app/chat/spatial/AddAgentButton.tsx +++ b/web/src/app/chat/spatial/AddAgentButton.tsx @@ -1,5 +1,5 @@ /** - * AddAgentButton - Floating action button to add new agents + * AddAgentButton - Small text button to add new agents */ import { PlusIcon } from "@heroicons/react/24/outline"; import { motion } from "framer-motion"; @@ -12,28 +12,24 @@ export function AddAgentButton({ onClick }: AddAgentButtonProps) { return ( - - - {/* Ripple effect on hover */} - + + 添加助手 ); } diff --git a/web/src/app/chat/spatial/FocusedView.tsx b/web/src/app/chat/spatial/FocusedView.tsx index ae51a488..0231d1f4 100644 --- a/web/src/app/chat/spatial/FocusedView.tsx +++ b/web/src/app/chat/spatial/FocusedView.tsx @@ -1,3 +1,5 @@ +import XyzenChat from "@/components/layouts/XyzenChat"; +import { useXyzen } from "@/store"; import { motion } from "framer-motion"; import { useEffect, useRef } from "react"; import { AgentData } from "./types"; @@ -18,6 +20,17 @@ export function FocusedView({ const switcherRef = useRef(null); const chatRef = useRef(null); + const { activateChannelForAgent } = useXyzen(); + + // Activate the channel for the selected agent + useEffect(() => { + if (agent.agentId) { + activateChannelForAgent(agent.agentId).catch((error) => { + console.error("Failed to activate channel for agent:", error); + }); + } + }, [agent.agentId, activateChannelForAgent]); + useEffect(() => { const onKeyDown = (e: KeyboardEvent) => { if (e.key === "Escape") onClose(); @@ -34,6 +47,15 @@ export function FocusedView({ if (chatRef.current?.contains(target)) return; if (switcherRef.current?.contains(target)) return; + // Clicking inside Radix portals (Sheet, Dialog, etc.) should not close. + // These are rendered outside our ref tree via Portal. + if ( + target.closest( + "[data-radix-portal], [data-slot='sheet-overlay'], [data-slot='sheet-content']", + ) + ) + return; + // Prevent XYFlow from starting a pan/drag on the same click, // which can override the restore viewport animation. e.preventDefault(); @@ -89,16 +111,16 @@ export function FocusedView({ /> {a.status === "busy" && ( - - + + )}
-
-
+
+
{a.name}
-
+
{a.role}
@@ -108,68 +130,17 @@ export function FocusedView({
- {/* 2. Main Chat Area */} + {/* 2. Main Chat Area - Frosted Glass Panel */} - {/* Chat Header */} -
-
-
- - Session Active - -
-
{/* Tools icons or standard window controls */}
-
- - {/* Chat Body Mock */} -
-
-
-
-

- Hello! I'm {agent.name}. I'm ready to assist you with your tasks - today. I can access your latest files and context. -

-
-
-
- - {/* Input Area */} -
-
- - -
-
+ {/* XyzenChat Component - No modifications, just wrapped */} +
); diff --git a/web/src/components/layouts/XyzenChat.tsx b/web/src/components/layouts/XyzenChat.tsx index a79219fa..d177939c 100644 --- a/web/src/components/layouts/XyzenChat.tsx +++ b/web/src/components/layouts/XyzenChat.tsx @@ -55,6 +55,11 @@ const ThemedWelcomeMessage: React.FC<{ iconType: "chat", iconColor: "indigo", category: "general", + avatar: + currentAgent.avatar || + (currentAgent.tags?.includes("default_chat") + ? "/defaults/agents/avatar1.png" + : "/defaults/agents/avatar2.png"), } : undefined } @@ -233,7 +238,7 @@ function BaseChat({ config, historyEnabled = false }: BaseChatProps) { )} {/* Messages Area */} -
+
state.confirmToolCall); const cancelToolCall = useXyzen((state) => state.cancelToolCall); const activeChatChannel = useXyzen((state) => state.activeChatChannel); + const channels = useXyzen((state) => state.channels); + const agents = useXyzen((state) => state.agents); + const user = useXyzen((state) => state.user); + + // Get current agent avatar from store + const currentChannel = activeChatChannel ? channels[activeChatChannel] : null; + const currentAgent = currentChannel?.agentId + ? agents.find((a) => a.id === currentChannel.agentId) + : null; const { role, @@ -70,9 +79,19 @@ function ChatBubble({ message }: ChatBubbleProps) { ? "rounded-sm border-green-400 bg-green-50/30 dark:border-green-500 dark:bg-green-900/10" : loadingStyles; - // 渲染头像,使用初始字母作为最后的备用选项 + // 渲染头像 const renderAvatar = () => { if (isUserMessage) { + // User avatar from store or fallback to ProfileIcon + if (user?.avatar) { + return ( + {user.username + ); + } return ( ); @@ -81,23 +100,34 @@ function ChatBubble({ message }: ChatBubbleProps) { if (isToolMessage) { // Tool message icon return ( -
+
🔧
); } - // AI助手头像显示首字母 - const initial = role?.charAt(0)?.toUpperCase() || "A"; + // AI agent avatar from store + if (currentAgent?.avatar) { + return ( + {currentAgent.name} + ); + } + + // Fallback to default agent avatars + const defaultAvatar = currentAgent?.tags?.includes("default_chat") + ? "/defaults/agents/avatar1.png" + : "/defaults/agents/avatar2.png"; return ( -
- {initial} -
+ Agent ); }; diff --git a/web/src/components/layouts/components/WelcomeMessage.tsx b/web/src/components/layouts/components/WelcomeMessage.tsx index 81cc1acd..4d4b6de3 100644 --- a/web/src/components/layouts/components/WelcomeMessage.tsx +++ b/web/src/components/layouts/components/WelcomeMessage.tsx @@ -1,3 +1,4 @@ +import { useXyzen } from "@/store"; import { motion } from "framer-motion"; import React from "react"; @@ -24,40 +25,49 @@ export interface Assistant { iconType: string; iconColor: string; category: string; + avatar?: string; // Agent avatar URL chats?: ChatData[]; // 与该助手的历史对话列表 } interface WelcomeMessageProps { assistant?: Assistant | null; + onQuickAction?: (action: string) => void; } -const WelcomeMessage: React.FC = ({ assistant }) => { - const iconColor = assistant?.iconColor || "indigo"; +const WelcomeMessage: React.FC = ({ + assistant, + onQuickAction, +}) => { + // Get sendMessage from store for quick actions + const sendMessage = useXyzen((state) => state.sendMessage); - // Fix dynamic class name issue by mapping to pre-defined classes - const iconColorMap: Record = { - blue: "bg-blue-50 dark:bg-blue-900/20", - green: "bg-green-50 dark:bg-green-900/20", - purple: "bg-purple-50 dark:bg-purple-900/20", - amber: "bg-amber-50 dark:bg-amber-900/20", - red: "bg-red-50 dark:bg-red-900/20", - indigo: "bg-indigo-50 dark:bg-indigo-900/20", - }; + // Quick action suggestions + const quickActions = [ + { emoji: "👋", label: "Say hello", message: "Hello! Nice to meet you." }, + { + emoji: "💡", + label: "Ask a question", + message: "Can you help me with something?", + }, + { + emoji: "📝", + label: "Start a task", + message: "I'd like to start a new task.", + }, + ]; - const iconTextColorMap: Record = { - blue: "text-blue-600 dark:text-blue-400", - green: "text-green-600 dark:text-green-400", - purple: "text-purple-600 dark:text-purple-400", - amber: "text-amber-600 dark:text-amber-400", - red: "text-red-600 dark:text-red-400", - indigo: "text-indigo-600 dark:text-indigo-400", + const handleQuickAction = (message: string) => { + if (onQuickAction) { + onQuickAction(message); + } else { + sendMessage(message); + } }; - const bgColorClass = iconColorMap[iconColor] || iconColorMap.indigo; - const textColorClass = iconTextColorMap[iconColor] || iconTextColorMap.indigo; - // Determine title and message based on whether an assistant is selected - const title = assistant ? `欢迎使用 ${assistant.title}` : "欢迎使用自由对话"; + const title = assistant + ? `Start a conversation with ${assistant.title}` + : "欢迎使用自由对话"; const description = assistant?.description || "您现在可以自由提问任何问题。无需选择特定助手,系统将根据您的问题提供合适的回复。"; @@ -69,28 +79,41 @@ const WelcomeMessage: React.FC = ({ assistant }) => { transition={{ duration: 0.6 }} className="flex flex-col items-center justify-center space-y-4 p-6 text-center" > -
- - +
+ {assistant.title} - -
+
+ ) : ( +
+ + + +
+ )} + -

+

{title} = ({ assistant }) => { {description}

+ + {/* Quick Action Buttons */} + + {quickActions.map((action, index) => ( + + ))} + ); }; diff --git a/web/src/store/slices/agentSlice.ts b/web/src/store/slices/agentSlice.ts index 9348913a..6b875a8a 100644 --- a/web/src/store/slices/agentSlice.ts +++ b/web/src/store/slices/agentSlice.ts @@ -138,6 +138,19 @@ export const createAgentSlice: StateCreator< fetchAgents: async () => { set({ agentsLoading: true }); try { + // Store existing layouts before fetching (to preserve unsaved layouts) + const existingAgents = get().agents; + const existingLayoutMap: Record = {}; + const existingAvatarMap: Record = {}; + existingAgents.forEach((agent) => { + if (agent.spatial_layout) { + existingLayoutMap[agent.id] = agent.spatial_layout; + } + if (agent.avatar) { + existingAvatarMap[agent.id] = agent.avatar; + } + }); + const response = await fetch(`${get().backendUrl}/xyzen/api/v1/agents/`, { headers: createAuthHeaders(), }); @@ -172,12 +185,18 @@ export const createAgentSlice: StateCreator< }), ); - // Enrich agents with layout and avatar from session or default + // Enrich agents with layout and avatar from: + // 1. Session (highest priority - persisted) + // 2. Existing local state (preserve unsaved changes) + // 3. Default values (fallback for new agents) const agents: AgentWithLayout[] = rawAgents.map((agent, index) => ({ ...agent, spatial_layout: - layoutMap[agent.id] ?? defaultSpatialLayoutForIndex(index), - avatar: avatarMap[agent.id] ?? agent.avatar, + layoutMap[agent.id] ?? + existingLayoutMap[agent.id] ?? + defaultSpatialLayoutForIndex(index), + avatar: + avatarMap[agent.id] ?? existingAvatarMap[agent.id] ?? agent.avatar, })); set({ @@ -453,7 +472,7 @@ export const createAgentSlice: StateCreator< } catch (fetchError) { // Session doesn't exist - create one first console.warn( - `No session found for agent ${agentId}, creating one...`, + `${fetchError} No session found for agent ${agentId}, creating one...`, ); const agent = get().agents.find((a) => a.id === agentId); const newSession = await sessionService.createSession({ diff --git a/web/src/store/slices/chatSlice.ts b/web/src/store/slices/chatSlice.ts index dcde7c55..29344edb 100644 --- a/web/src/store/slices/chatSlice.ts +++ b/web/src/store/slices/chatSlice.ts @@ -1,8 +1,24 @@ +import { generateClientId, groupToolMessagesWithAssistant } from "@/core/chat"; +import { providerCore } from "@/core/provider"; import { authService } from "@/service/authService"; -import xyzenService from "@/service/xyzenService"; import { sessionService } from "@/service/sessionService"; -import { providerCore } from "@/core/provider"; -import { generateClientId, groupToolMessagesWithAssistant } from "@/core/chat"; +import xyzenService from "@/service/xyzenService"; +import type { + AgentEndData, + AgentErrorData, + AgentExecutionState, + AgentStartData, + IterationEndData, + IterationStartData, + NodeEndData, + NodeStartData, + PhaseEndData, + PhaseExecution, + PhaseStartData, + ProgressUpdateData, + SubagentEndData, + SubagentStartData, +} from "@/types/agentEvents"; import type { StateCreator } from "zustand"; import type { ChatChannel, @@ -11,22 +27,6 @@ import type { TopicResponse, XyzenState, } from "../types"; -import type { - AgentStartData, - AgentEndData, - AgentErrorData, - PhaseStartData, - PhaseEndData, - NodeStartData, - NodeEndData, - SubagentStartData, - SubagentEndData, - ProgressUpdateData, - IterationStartData, - IterationEndData, - AgentExecutionState, - PhaseExecution, -} from "@/types/agentEvents"; // NOTE: groupToolMessagesWithAssistant 
and generateClientId have been moved to // @/core/chat/messageProcessor.ts as part of the frontend refactor. @@ -53,6 +53,7 @@ export interface ChatSlice { fetchChatHistory: () => Promise; togglePinChat: (chatId: string) => void; activateChannel: (topicId: string) => Promise; + activateChannelForAgent: (agentId: string) => Promise; connectToChannel: (sessionId: string, topicId: string) => void; disconnectFromChannel: () => void; sendMessage: (message: string) => Promise; @@ -395,6 +396,127 @@ export const createChatSlice: StateCreator< } }, + /** + * Activate or create a chat channel for a specific agent. + * This is used by the spatial workspace to open chat with an agent. + * - If a session exists for the agent, activates the most recent topic + * - If no session exists, creates one with a default topic + */ + activateChannelForAgent: async (agentId: string) => { + const { channels, chatHistory, backendUrl } = get(); + + // First, check if we already have a channel for this agent + const existingChannel = Object.values(channels).find( + (ch) => ch.agentId === agentId, + ); + + if (existingChannel) { + // Already have a channel, activate it + await get().activateChannel(existingChannel.id); + return; + } + + // Check chat history for existing topics with this agent + const existingHistory = chatHistory.find((h) => h.sessionId === agentId); + if (existingHistory) { + await get().activateChannel(existingHistory.id); + return; + } + + // No existing channel, try to find or create a session for this agent + const token = authService.getToken(); + if (!token) { + console.error("No authentication token available"); + return; + } + + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }; + + try { + // Try to get existing session for this agent + const sessionResponse = await fetch( + `${backendUrl}/xyzen/api/v1/sessions/by-agent/${agentId}`, + { headers }, + ); + + if (sessionResponse.ok) { + const session = await sessionResponse.json(); + + // Get the most recent topic for this session, or create one + if (session.topics && session.topics.length > 0) { + // Activate the most recent topic + const latestTopic = session.topics[session.topics.length - 1]; + + // Create channel if doesn't exist + const channel: ChatChannel = { + id: latestTopic.id, + sessionId: session.id, + title: latestTopic.name, + messages: [], + agentId: session.agent_id, + provider_id: session.provider_id, + model: session.model, + google_search_enabled: session.google_search_enabled, + connected: false, + error: null, + }; + + set((state) => { + state.channels[latestTopic.id] = channel; + }); + + await get().activateChannel(latestTopic.id); + } else { + // Session exists but no topics, create a default topic + const topicResponse = await fetch( + `${backendUrl}/xyzen/api/v1/topics/`, + { + method: "POST", + headers, + body: JSON.stringify({ + name: "新的聊天", + session_id: session.id, + }), + }, + ); + + if (topicResponse.ok) { + const newTopic = await topicResponse.json(); + + const channel: ChatChannel = { + id: newTopic.id, + sessionId: session.id, + title: newTopic.name, + messages: [], + agentId: session.agent_id, + provider_id: session.provider_id, + model: session.model, + google_search_enabled: session.google_search_enabled, + connected: false, + error: null, + }; + + set((state) => { + state.channels[newTopic.id] = channel; + }); + + await get().activateChannel(newTopic.id); + } + } + } else { + // No session exists, create one via createDefaultChannel + await 
get().createDefaultChannel(agentId); + } + } catch (error) { + console.error("Failed to activate channel for agent:", error); + // Fallback to createDefaultChannel + await get().createDefaultChannel(agentId); + } + }, + connectToChannel: (sessionId: string, topicId: string) => { xyzenService.disconnect(); xyzenService.connect( From f640da9952407f491b2c36dc18c579310e99c25a Mon Sep 17 00:00:00 2001 From: Harvey Date: Thu, 15 Jan 2026 03:54:17 +0800 Subject: [PATCH 10/11] feat: Update default agent avatar to use DiceBear API for consistency across components --- service/app/core/system_agent.py | 2 +- web/src/app/chat/spatial/AgentNode.tsx | 14 ++++++++++++-- web/src/components/layouts/XyzenAgent.tsx | 4 +--- web/src/components/layouts/XyzenChat.tsx | 8 ++------ .../components/layouts/components/ChatBubble.tsx | 8 ++------ web/src/components/modals/ChatPreview.tsx | 4 +--- web/src/components/modals/ShareModal.tsx | 4 +--- 7 files changed, 20 insertions(+), 24 deletions(-) diff --git a/service/app/core/system_agent.py b/service/app/core/system_agent.py index d638cb0f..92dbe913 100644 --- a/service/app/core/system_agent.py +++ b/service/app/core/system_agent.py @@ -46,7 +46,7 @@ class AgentConfig(TypedDict): 你的目标是成为用户最可靠的AI助手,帮助他们解决问题并提供有价值的信息。""", "personality": "friendly_assistant", "capabilities": ["general_chat", "qa", "assistance", "tools"], - "avatar": "/defaults/agents/avatar1.png", + "avatar": "https://api.dicebear.com/7.x/avataaars/svg?seed=default", "tags": ["助手", "对话", "工具", "帮助"], }, } diff --git a/web/src/app/chat/spatial/AgentNode.tsx b/web/src/app/chat/spatial/AgentNode.tsx index 519f07d9..a5c0dfd7 100644 --- a/web/src/app/chat/spatial/AgentNode.tsx +++ b/web/src/app/chat/spatial/AgentNode.tsx @@ -376,9 +376,19 @@ export function AgentNode({ id, data, selected }: AgentFlowNodeProps) { { // Only trigger focus if we are NOT clicking inside the settings menu interactions e.stopPropagation(); diff --git a/web/src/components/layouts/XyzenAgent.tsx b/web/src/components/layouts/XyzenAgent.tsx index 7bba3aea..9dc3e9fa 100644 --- a/web/src/components/layouts/XyzenAgent.tsx +++ b/web/src/components/layouts/XyzenAgent.tsx @@ -257,9 +257,7 @@ const AgentCard: React.FC = ({ {agent.name} diff --git a/web/src/components/modals/ChatPreview.tsx b/web/src/components/modals/ChatPreview.tsx index 19dc93cf..1a5a511f 100644 --- a/web/src/components/modals/ChatPreview.tsx +++ b/web/src/components/modals/ChatPreview.tsx @@ -45,9 +45,7 @@ const ChatPreview: React.FC = ({ // AI 机器人头像 const robotAvatarUrl = currentAgent?.avatar || - (currentAgent?.tags?.includes("default_chat") - ? "/defaults/agents/avatar1.png" - : "/defaults/agents/avatar2.png"); + "https://api.dicebear.com/7.x/avataaars/svg?seed=default"; // 用户名 const userName = currentUser?.username || "用户"; diff --git a/web/src/components/modals/ShareModal.tsx b/web/src/components/modals/ShareModal.tsx index 06238d39..747eb84e 100644 --- a/web/src/components/modals/ShareModal.tsx +++ b/web/src/components/modals/ShareModal.tsx @@ -206,9 +206,7 @@ export const ShareModal: React.FC = ({ // AI 头像逻辑 const robotAvatarUrl = currentAgent?.avatar || - (currentAgent?.tags && currentAgent.tags.includes("default_chat") - ? "/defaults/agents/avatar1.png" - : "/defaults/agents/avatar2.png"); + "https://api.dicebear.com/7.x/avataaars/svg?seed=default"; return (
Date: Thu, 15 Jan 2026 03:58:53 +0800 Subject: [PATCH 11/11] feat: Refactor ActivityPanel type and remove workspace-test references for cleaner code structure --- web/src/app/AppFullscreen.tsx | 39 ++-------------------- web/src/components/layouts/ActivityBar.tsx | 13 +------- web/src/store/slices/uiSlice/index.ts | 6 +--- 3 files changed, 4 insertions(+), 54 deletions(-) diff --git a/web/src/app/AppFullscreen.tsx b/web/src/app/AppFullscreen.tsx index 1da94aa2..b0bf1190 100644 --- a/web/src/app/AppFullscreen.tsx +++ b/web/src/app/AppFullscreen.tsx @@ -10,14 +10,11 @@ import AgentMarketplace from "@/app/marketplace/AgentMarketplace"; import { ActivityBar } from "@/components/layouts/ActivityBar"; import { AppHeader } from "@/components/layouts/AppHeader"; import KnowledgeBase from "@/components/layouts/KnowledgeBase"; -import XyzenAgent from "@/components/layouts/XyzenAgent"; -import XyzenChat from "@/components/layouts/XyzenChat"; import { PwaInstallPrompt } from "@/components/features/PwaInstallPrompt"; import { SettingsModal } from "@/components/modals/SettingsModal"; import { DEFAULT_BACKEND_URL } from "@/configs"; -import { useTranslation } from "react-i18next"; export interface AppFullscreenProps { backendUrl?: string; @@ -26,7 +23,6 @@ export interface AppFullscreenProps { export function AppFullscreen({ backendUrl = DEFAULT_BACKEND_URL, }: AppFullscreenProps) { - const { t } = useTranslation(); const { setBackendUrl, // centralized UI actions @@ -80,33 +76,8 @@ export function AppFullscreen({ {/* Panel Content */}
{activePanel === "chat" && ( -
- {/* Left Sidebar: Assistants - Only show if no active chat or we want a split view? - In fullscreen, we typically want the list AND the chat. - However, to match AppSide logic where we drill down: - */} - - {/* For fullscreen, we can keep the sidebar + chat layout for the "chat" panel */} - - - {/* Right Column: Chat Interface */} -
- -
+
+
)} @@ -121,12 +92,6 @@ export function AppFullscreen({
)} - - {activePanel === "workspace-test" && ( -
- -
- )}
diff --git a/web/src/components/layouts/ActivityBar.tsx b/web/src/components/layouts/ActivityBar.tsx index affb1d71..7f1e39a2 100644 --- a/web/src/components/layouts/ActivityBar.tsx +++ b/web/src/components/layouts/ActivityBar.tsx @@ -2,16 +2,11 @@ import { ChatBubbleLeftRightIcon, FolderIcon, SparklesIcon, - Squares2X2Icon, } from "@heroicons/react/24/outline"; import { motion } from "framer-motion"; import { useTranslation } from "react-i18next"; -export type ActivityPanel = - | "chat" - | "knowledge" - | "marketplace" - | "workspace-test"; +export type ActivityPanel = "chat" | "knowledge" | "marketplace"; interface ActivityBarProps { activePanel: ActivityPanel; @@ -121,12 +116,6 @@ export const ActivityBar: React.FC = ({ label: t("app.activityBar.community"), disabled: false, }, - { - panel: "workspace-test" as ActivityPanel, - icon: Squares2X2Icon, - label: "Workspace Concept", - disabled: false, - }, ]; return ( diff --git a/web/src/store/slices/uiSlice/index.ts b/web/src/store/slices/uiSlice/index.ts index 85d9bf0b..5cbaa024 100644 --- a/web/src/store/slices/uiSlice/index.ts +++ b/web/src/store/slices/uiSlice/index.ts @@ -8,11 +8,7 @@ import { type InputPosition, type LayoutStyle } from "./types"; // Ensure xyzen service is aware of the default backend on startup xyzenService.setBackendUrl(DEFAULT_BACKEND_URL); -export type ActivityPanel = - | "chat" - | "knowledge" - | "marketplace" - | "workspace-test"; +export type ActivityPanel = "chat" | "knowledge" | "marketplace"; export interface UiSlice { backendUrl: string;