Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions invokeai/app/api/routers/boards.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ async def list_boards(
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards for the current user, including shared boards"""
"""Gets a list of boards for the current user, including shared boards. Admin users see all boards."""
if all:
return ApiDependencies.invoker.services.boards.get_all(
current_user.user_id, order_by, direction, include_archived
current_user.user_id, current_user.is_admin, order_by, direction, include_archived
)
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(
current_user.user_id, order_by, direction, offset, limit, include_archived
current_user.user_id, current_user.is_admin, order_by, direction, offset, limit, include_archived
)
else:
raise HTTPException(
Expand Down
18 changes: 16 additions & 2 deletions invokeai/app/api/routers/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ async def upload_image(
workflow=extracted_metadata.invokeai_workflow,
graph=extracted_metadata.invokeai_graph,
is_intermediate=is_intermediate,
user_id=current_user.user_id,
)

response.status_code = 201
Expand Down Expand Up @@ -375,6 +376,7 @@ async def get_image_urls(
response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_image_dtos(
current_user: CurrentUser,
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
Expand All @@ -388,10 +390,19 @@ async def list_image_dtos(
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs"""
"""Gets a list of image DTOs for the current user"""

image_dtos = ApiDependencies.invoker.services.images.get_many(
offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
offset,
limit,
starred_first,
order_dir,
image_origin,
categories,
is_intermediate,
board_id,
search_term,
current_user.user_id,
)

return image_dtos
Expand Down Expand Up @@ -569,6 +580,7 @@ async def get_bulk_download_item(

@images_router.get("/names", operation_id="get_image_names")
async def get_image_names(
current_user: CurrentUser,
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
Expand All @@ -591,6 +603,8 @@ async def get_image_names(
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
user_id=current_user.user_id,
is_admin=current_user.is_admin,
)
return result
except Exception:
Expand Down
6 changes: 4 additions & 2 deletions invokeai/app/services/board_records/board_records_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,24 @@ def update(
def get_many(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records for a specific user, including shared boards."""
"""Gets many board records for a specific user, including shared boards. Admin users see all boards."""
pass

@abstractmethod
def get_all(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
include_archived: bool = False,
) -> list[BoardRecord]:
"""Gets all board records for a specific user, including shared boards."""
"""Gets all board records for a specific user, including shared boards. Admin users see all boards."""
pass
5 changes: 5 additions & 0 deletions invokeai/app/services/board_records/board_records_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class BoardRecord(BaseModelExcludeNull):
"""The unique ID of the board."""
board_name: str = Field(description="The name of the board.")
"""The name of the board."""
user_id: str = Field(description="The user ID of the board owner.")
"""The user ID of the board owner."""
created_at: Union[datetime, str] = Field(description="The created timestamp of the board.")
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
Expand All @@ -35,6 +37,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:

board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown")
# Default to 'system' for backwards compatibility with boards created before multiuser support
user_id = board_dict.get("user_id", "system")
cover_image_name = board_dict.get("cover_image_name", "unknown")
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
Expand All @@ -44,6 +48,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
return BoardRecord(
board_id=board_id,
board_name=board_name,
user_id=user_id,
cover_image_name=cover_image_name,
created_at=created_at,
updated_at=updated_at,
Expand Down
110 changes: 85 additions & 25 deletions invokeai/app/services/board_records/board_records_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,35 @@ def update(
def get_many(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
with self._db.transaction() as cursor:
# Build base query - include boards owned by user, shared with user, or public
base_query = """
# Build base query - admins see all boards, regular users see owned, shared, or public boards
if is_admin:
base_query = """
SELECT DISTINCT boards.*
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""

# Determine archived filter condition
archived_filter = "WHERE 1=1" if include_archived else "WHERE boards.archived = 0"

final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)

# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
else:
base_query = """
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
Expand All @@ -141,38 +161,52 @@ def get_many(
LIMIT ? OFFSET ?;
"""

# Determine archived filter condition
archived_filter = "" if include_archived else "AND boards.archived = 0"
# Determine archived filter condition
archived_filter = "" if include_archived else "AND boards.archived = 0"

final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)

# Execute query to fetch boards
cursor.execute(final_query, (user_id, user_id, limit, offset))
# Execute query to fetch boards
cursor.execute(final_query, (user_id, user_id, limit, offset))

result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

# Determine count query - count boards accessible to user
if include_archived:
count_query = """
# Determine count query - admins count all boards, regular users count accessible boards
if is_admin:
if include_archived:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
WHERE boards.archived = 0;
"""
cursor.execute(count_query)
else:
if include_archived:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1);
"""
else:
count_query = """
else:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
AND boards.archived = 0;
"""

# Execute count query
cursor.execute(count_query, (user_id, user_id))
# Execute count query
cursor.execute(count_query, (user_id, user_id))

count = cast(int, cursor.fetchone()[0])

Expand All @@ -181,22 +215,48 @@ def get_many(
def get_all(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
include_archived: bool = False,
) -> list[BoardRecord]:
with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
# Build query - admins see all boards, regular users see owned, shared, or public boards
if is_admin:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT DISTINCT boards.*
FROM boards
{archived_filter}
ORDER BY LOWER(boards.board_name) {direction}
"""
else:
base_query = """
SELECT DISTINCT boards.*
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""

archived_filter = "WHERE 1=1" if include_archived else "WHERE boards.archived = 0"

final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)

cursor.execute(final_query)
else:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
{archived_filter}
ORDER BY LOWER(boards.board_name) {direction}
"""
else:
base_query = """
else:
base_query = """
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
Expand All @@ -205,13 +265,13 @@ def get_all(
ORDER BY {order_by} {direction}
"""

archived_filter = "" if include_archived else "AND boards.archived = 0"
archived_filter = "" if include_archived else "AND boards.archived = 0"

final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)

cursor.execute(final_query, (user_id, user_id))
cursor.execute(final_query, (user_id, user_id))

result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
Expand Down
6 changes: 4 additions & 2 deletions invokeai/app/services/boards/boards_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,24 @@ def delete(
def get_many(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards for a specific user, including shared boards."""
"""Gets many boards for a specific user, including shared boards. Admin users see all boards."""
pass

@abstractmethod
def get_all(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
include_archived: bool = False,
) -> list[BoardDTO]:
"""Gets all boards for a specific user, including shared boards."""
"""Gets all boards for a specific user, including shared boards. Admin users see all boards."""
pass
9 changes: 8 additions & 1 deletion invokeai/app/services/boards/boards_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,22 @@ class BoardDTO(BoardRecord):
"""The number of images in the board."""
asset_count: int = Field(description="The number of assets in the board.")
"""The number of assets in the board."""
owner_username: Optional[str] = Field(default=None, description="The username of the board owner (for admin view).")
"""The username of the board owner (for admin view)."""


def board_record_to_dto(
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int, asset_count: int
board_record: BoardRecord,
cover_image_name: Optional[str],
image_count: int,
asset_count: int,
owner_username: Optional[str] = None,
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.model_dump(exclude={"cover_image_name"}),
cover_image_name=cover_image_name,
image_count=image_count,
asset_count=asset_count,
owner_username=owner_username,
)
Loading
Loading