Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions service/app/api/v1/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,31 @@ async def create_session_with_default_topic(
raise handle_auth_error(e)


@router.get("/by-agent/{agent_id}", response_model=SessionRead)
@router.get("/by-agent/{agent_id}", response_model=SessionReadWithTopics)
async def get_session_by_agent(
agent_id: str, user: str = Depends(get_current_user), db: AsyncSession = Depends(get_session)
) -> SessionRead:
) -> SessionReadWithTopics:
"""
Retrieve a session for the current user with a specific agent.
Retrieve a session for the current user with a specific agent, including topics.

Finds a session associated with the given agent ID for the authenticated user.
The agent_id can be "default" for sessions without an agent, a UUID string
for sessions with a specific agent, or a builtin agent string ID.
Topics are ordered by updated_at descending (most recent first).

Args:
agent_id: Agent identifier ("default", UUID string, or builtin agent ID)
user: Authenticated user ID (injected by dependency)
db: Database session (injected by dependency)

Returns:
SessionRead: The session associated with the user and agent
SessionReadWithTopics: The session with topics associated with the user and agent

Raises:
HTTPException: 404 if no session found for this user-agent combination
"""
try:
return await SessionService(db).get_session_by_agent(user, agent_id)
return await SessionService(db).get_session_by_agent_with_topics(user, agent_id)
except ErrCodeError as e:
raise handle_auth_error(e)

Expand Down
14 changes: 14 additions & 0 deletions service/app/core/session/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,20 @@ async def get_session_by_agent(self, user_id: str, agent_id: str) -> SessionRead
raise ErrCode.SESSION_NOT_FOUND.with_messages("No session found for this user-agent combination")
return SessionRead(**session.model_dump())

async def get_session_by_agent_with_topics(self, user_id: str, agent_id: str) -> SessionReadWithTopics:
agent_uuid = await self._resolve_agent_uuid_for_lookup(agent_id)
session = await self.session_repo.get_session_by_user_and_agent(user_id, agent_uuid)
if not session:
raise ErrCode.SESSION_NOT_FOUND.with_messages("No session found for this user-agent combination")

# Fetch topics ordered by updated_at descending (most recent first)
topics = await self.topic_repo.get_topics_by_session(session.id, order_by_updated=True)
topic_reads = [TopicRead(**topic.model_dump()) for topic in topics]

session_dict = session.model_dump()
session_dict["topics"] = topic_reads
return SessionReadWithTopics(**session_dict)

async def get_sessions_with_topics(self, user_id: str) -> list[SessionReadWithTopics]:
sessions = await self.session_repo.get_sessions_by_user_ordered_by_activity(user_id)

Expand Down
167 changes: 105 additions & 62 deletions service/app/tools/builtin/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from uuid import UUID, uuid4

from langchain_core.tools import BaseTool, StructuredTool
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from app.configs import configs
from app.core.storage import FileScope, generate_storage_key, get_storage_service
Expand All @@ -24,6 +24,9 @@

# --- Input Schemas ---

# Maximum number of reference images allowed for generation
MAX_INPUT_IMAGES = 4


class GenerateImageInput(BaseModel):
"""Input schema for generate_image tool."""
Expand All @@ -35,14 +38,24 @@ class GenerateImageInput(BaseModel):
default="1:1",
description="Aspect ratio of the generated image.",
)
image_id: str | None = Field(
image_ids: list[str] | None = Field(
default=None,
description=(
"Optional image UUID to use as a reference input. "
"Use the 'image_id' value returned from generate_image or upload tools."
f"Optional list of image UUIDs (max {MAX_INPUT_IMAGES}) to use as reference inputs. "
"Use the 'image_id' values returned from generate_image or upload tools."
),
)

@model_validator(mode="after")
def validate_image_inputs(self) -> "GenerateImageInput":
"""Validate image_ids field."""
if self.image_ids:
if len(self.image_ids) > MAX_INPUT_IMAGES:
raise ValueError(f"Maximum {MAX_INPUT_IMAGES} input images allowed, got {len(self.image_ids)}")
if len(self.image_ids) == 0:
self.image_ids = None # Normalize empty list to None
return self


class ReadImageInput(BaseModel):
"""Input schema for read_image tool."""
Expand All @@ -62,8 +75,7 @@ class ReadImageInput(BaseModel):
async def _generate_image_with_langchain(
prompt: str,
aspect_ratio: str = "1:1",
image_bytes: bytes | None = None,
image_mime_type: str | None = None,
images: list[tuple[bytes, str]] | None = None,
) -> tuple[bytes, str]:
"""
Generate an image using LangChain ChatGoogleGenerativeAI via ProviderManager.
Expand All @@ -76,6 +88,7 @@ async def _generate_image_with_langchain(
Args:
prompt: Text description of the image to generate
aspect_ratio: Aspect ratio for the generated image
images: Optional list of (image_bytes, mime_type) tuples to use as references

Returns:
Tuple of (image_bytes, mime_type)
Expand All @@ -102,25 +115,34 @@ async def _generate_image_with_langchain(
)

# Request image generation via LangChain
if image_bytes and image_mime_type:
b64_data = base64.b64encode(image_bytes).decode("utf-8")
message = HumanMessage(
content=[
if images:
# Build content array with multiple image_url blocks
content: list[dict[str, Any]] = []
for image_bytes, image_mime_type in images:
b64_data = base64.b64encode(image_bytes).decode("utf-8")
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{image_mime_type};base64,{b64_data}",
},
},
{
"type": "text",
"text": (
"Use the provided image as a reference. "
f"Generate a new image with aspect ratio {aspect_ratio}: {prompt}"
),
},
]
}
)

# Add text prompt with appropriate phrasing for single vs multiple images
image_count = len(images)
if image_count == 1:
reference_text = "Use the provided image as a reference."
else:
reference_text = f"Use these {image_count} provided images as references."

content.append(
{
"type": "text",
"text": f"{reference_text} Generate a new image with aspect ratio {aspect_ratio}: {prompt}",
}
)
message = HumanMessage(content=content) # type: ignore[arg-type]
else:
message = HumanMessage(content=f"Generate an image with aspect ratio {aspect_ratio}: {prompt}")
response = await llm.ainvoke([message])
Expand Down Expand Up @@ -172,46 +194,67 @@ async def _generate_image_with_langchain(
raise ValueError("No image data in response. Model may not support image generation.")


async def _load_image_for_generation(user_id: str, image_id: str) -> tuple[bytes, str, str]:
async def _load_images_for_generation(user_id: str, image_ids: list[str]) -> list[tuple[bytes, str, str]]:
"""
Load multiple images for generation from the database.

Args:
user_id: User ID for permission check
image_ids: List of image UUIDs to load

Returns:
List of tuples: (image_bytes, mime_type, storage_key)

Raises:
ValueError: If any image_id is invalid, not found, deleted, or inaccessible
"""
from app.infra.database import create_task_session_factory
from app.repos.file import FileRepository

try:
file_uuid = UUID(image_id)
except ValueError as exc:
raise ValueError(f"Invalid image_id format: {image_id}") from exc
results: list[tuple[bytes, str, str]] = []

# Create a fresh session factory for the current event loop (Celery worker)
TaskSessionLocal = create_task_session_factory()

async with TaskSessionLocal() as db:
file_repo = FileRepository(db)
file_record = await file_repo.get_file_by_id(file_uuid)
storage = get_storage_service()

if file_record is None:
raise ValueError(f"Image not found: {image_id}")
for image_id in image_ids:
try:
file_uuid = UUID(image_id)
except ValueError as exc:
raise ValueError(f"Invalid image_id format: {image_id}") from exc

if file_record.is_deleted:
raise ValueError(f"Image has been deleted: {image_id}")
file_record = await file_repo.get_file_by_id(file_uuid)

if file_record.user_id != user_id and file_record.scope != "public":
raise ValueError("Permission denied: you don't have access to this image")
if file_record is None:
raise ValueError(f"Image not found: {image_id}")

storage_key = file_record.storage_key
content_type = file_record.content_type or "image/png"
if file_record.is_deleted:
raise ValueError(f"Image has been deleted: {image_id}")

storage = get_storage_service()
buffer = io.BytesIO()
await storage.download_file(storage_key, buffer)
image_bytes = buffer.getvalue()
return image_bytes, content_type, storage_key
if file_record.user_id != user_id and file_record.scope != "public":
raise ValueError(f"Permission denied: you don't have access to image {image_id}")

storage_key = file_record.storage_key
content_type = file_record.content_type or "image/png"

# Download from storage
buffer = io.BytesIO()
await storage.download_file(storage_key, buffer)
image_bytes = buffer.getvalue()

results.append((image_bytes, content_type, storage_key))

return results


async def _generate_image(
user_id: str,
prompt: str,
aspect_ratio: str = "1:1",
image_id: str | None = None,
image_ids: list[str] | None = None,
) -> dict[str, Any]:
"""
Generate an image and store it to OSS, then register in database.
Expand All @@ -220,28 +263,27 @@ async def _generate_image(
user_id: User ID for storage organization
prompt: Image description
aspect_ratio: Aspect ratio for the image
image_ids: Optional list of image UUIDs to use as reference inputs

Returns:
Dictionary with success status, path, URL, and metadata
"""
try:
# Load optional reference image
source_image_bytes = None
source_mime_type = None
source_storage_key = None
source_image_id = image_id
if source_image_id:
source_image_bytes, source_mime_type, source_storage_key = await _load_image_for_generation(
user_id,
source_image_id,
)
# Load optional reference images
images_for_generation: list[tuple[bytes, str]] | None = None
source_storage_keys: list[str] = []
source_image_ids: list[str] = image_ids or []

if source_image_ids:
loaded_images = await _load_images_for_generation(user_id, source_image_ids)
images_for_generation = [(img[0], img[1]) for img in loaded_images]
source_storage_keys = [img[2] for img in loaded_images]

# Generate image using LangChain via ProviderManager
image_bytes, mime_type = await _generate_image_with_langchain(
prompt,
aspect_ratio,
image_bytes=source_image_bytes,
image_mime_type=source_mime_type,
images=images_for_generation,
)

# Determine file extension from mime type
Expand Down Expand Up @@ -290,27 +332,27 @@ async def _generate_image(
metainfo={
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"source_image_id": source_image_id,
"source_storage_key": source_storage_key,
"source_image_ids": source_image_ids,
"source_storage_keys": source_storage_keys,
},
)
file_record = await file_repo.create_file(file_data)
await db.commit()
# Refresh to get the generated UUID
await db.refresh(file_record)
image_id = str(file_record.id)
generated_image_id = str(file_record.id)

logger.info(f"Generated image for user {user_id}: {storage_key} (id={image_id})")
logger.info(f"Generated image for user {user_id}: {storage_key} (id={generated_image_id})")

return {
"success": True,
"image_id": image_id,
"image_id": generated_image_id,
"path": storage_key,
"url": url,
"markdown": f"![Generated Image]({url})",
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"source_image_id": source_image_id,
"source_image_ids": source_image_ids,
"mime_type": mime_type,
"size_bytes": len(image_bytes),
}
Expand Down Expand Up @@ -511,7 +553,7 @@ def create_image_tools() -> dict[str, BaseTool]:
async def generate_image_placeholder(
prompt: str,
aspect_ratio: str = "1:1",
image_id: str | None = None,
image_ids: list[str] | None = None,
) -> dict[str, Any]:
return {"error": "Image tools require agent context binding", "success": False}

Expand All @@ -520,7 +562,7 @@ async def generate_image_placeholder(
description=(
"Generate an image based on a text description. "
"Provide a detailed prompt describing the desired image. "
"To modify or generate based on a previous image, pass the 'image_id' from a previous generate_image result. "
f"To generate based on previous images, pass 'image_ids' with up to {MAX_INPUT_IMAGES} reference image UUIDs. "
"Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image. "
"TIP: You can use 'image_id' values when creating PPTX presentations with knowledge_write - see knowledge_help(topic='image_slides') for details."
),
Expand Down Expand Up @@ -564,17 +606,17 @@ def create_image_tools_for_agent(user_id: str) -> list[BaseTool]:
async def generate_image_bound(
prompt: str,
aspect_ratio: str = "1:1",
image_id: str | None = None,
image_ids: list[str] | None = None,
) -> dict[str, Any]:
return await _generate_image(user_id, prompt, aspect_ratio, image_id)
return await _generate_image(user_id, prompt, aspect_ratio, image_ids)

tools.append(
StructuredTool(
name="generate_image",
description=(
"Generate an image based on a text description. "
"Provide a detailed prompt describing the desired image including style, colors, composition, and subject. "
"To modify or generate based on a previous image, pass the 'image_id' from a previous generate_image result. "
f"To generate based on previous images, pass 'image_ids' with up to {MAX_INPUT_IMAGES} reference image UUIDs. "
"Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image to the user. "
"TIP: You can use 'image_id' values when creating beautiful PPTX presentations with knowledge_write in image_slides mode - call knowledge_help(topic='image_slides') for the full workflow."
),
Expand Down Expand Up @@ -611,4 +653,5 @@ async def read_image_bound(
"create_image_tools_for_agent",
"GenerateImageInput",
"ReadImageInput",
"MAX_INPUT_IMAGES",
]
7 changes: 4 additions & 3 deletions service/app/tools/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ def calculate_tool_cost(
config = tool_info.cost
cost = config.base_cost

# Add input image cost (for generate_image with reference)
# Add input image cost (for generate_image with reference images)
if config.input_image_cost and tool_args:
if tool_args.get("image_id"): # Has reference image
cost += config.input_image_cost
image_ids = tool_args.get("image_ids")
if image_ids:
cost += config.input_image_cost * len(image_ids)

# Add output file cost (for knowledge_write creating new files)
if config.output_file_cost and tool_result:
Expand Down
Loading