Skip to content

Commit 10e1b09

Browse files
author
Attila Cseh
committed
queue list virtualized
1 parent 704b405 commit 10e1b09

File tree

18 files changed

+455
-120
lines changed

18 files changed

+455
-120
lines changed

invokeai/app/api/routers/session_queue.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,9 @@ async def enqueue_batch(
8080
)
8181
async def list_queue_items(
8282
queue_id: str = Path(description="The queue id to perform this operation on"),
83-
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
84-
limit: int = Query(default=50, description="The number of queues per page"),
85-
order_by: QUEUE_ORDER_BY = Query(default="item_id", description="The status of items to fetch"),
86-
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
83+
limit: int = Query(default=50, description="The number of items to fetch"),
8784
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
85+
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
8886
priority: int = Query(default=0, description="The pagination cursor priority"),
8987
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
9088
) -> CursorPaginatedResults[SessionQueueItem]:
@@ -131,9 +129,9 @@ async def list_all_queue_items(
131129
200: {"model": ItemIdsResult},
132130
},
133131
)
134-
async def get_queue_itemIds(
132+
async def get_queue_item_ids(
135133
queue_id: str = Path(description="The queue id to perform this operation on"),
136-
order_by: QUEUE_ORDER_BY = Query(default="ITEM_ID", description="The sort field"),
134+
order_by: QUEUE_ORDER_BY = Query(default="completed_at", description="The sort field"),
137135
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
138136
) -> ItemIdsResult:
139137
"""Gets all queue item ids that match the given parameters"""
@@ -145,6 +143,36 @@ async def get_queue_itemIds(
145143
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
146144

147145

146+
@session_queue_router.post(
147+
"/{queue_id}/items_by_ids",
148+
operation_id="get_queue_items_by_item_ids",
149+
responses={200: {"model": list[SessionQueueItem]}},
150+
)
151+
async def get_queue_items_by_item_ids(
152+
queue_id: str = Path(description="The queue id to perform this operation on"),
153+
item_ids: list[int] = Body(
154+
embed=True, description="Object containing list of queue item ids to fetch queue items for"
155+
),
156+
) -> list[SessionQueueItem]:
157+
"""Gets queue items for the specified queue item ids. Maintains order of item ids."""
158+
try:
159+
session_queue_service = ApiDependencies.invoker.services.session_queue
160+
161+
# Fetch queue items preserving the order of requested item ids
162+
queue_items: list[SessionQueueItem] = []
163+
for item_id in item_ids:
164+
try:
165+
queue_item = session_queue_service.get_queue_item(item_id)
166+
queue_items.append(queue_item)
167+
except Exception:
168+
# Skip missing queue items - they may have been deleted between item id fetch and queue item fetch
169+
continue
170+
171+
return queue_items
172+
except Exception:
173+
raise HTTPException(status_code=500, detail="Failed to get queue items")
174+
175+
148176
@session_queue_router.put(
149177
"/{queue_id}/processor/resume",
150178
operation_id="resume",

invokeai/app/services/session_queue/session_queue_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def list_all_queue_items(
164164
def get_queue_itemIds(
165165
self,
166166
queue_id: str,
167-
order_by: QUEUE_ORDER_BY = "item_id",
167+
order_by: QUEUE_ORDER_BY = "completed_at",
168168
order_dir: SQLiteDirection = SQLiteDirection.Descending,
169169
) -> ItemIdsResult:
170170
"""Gets all queue item ids that match the given parameters"""

invokeai/app/services/session_queue/session_queue_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def validate_graph(cls, v: Graph):
174174

175175
DEFAULT_QUEUE_ID = "default"
176176

177-
QUEUE_ORDER_BY = Literal["item_id", "status", "created_at"]
177+
QUEUE_ORDER_BY = Literal["status", "completed_at"]
178178
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
179179

180180

invokeai/app/services/session_queue/session_queue_sqlite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def list_all_queue_items(
677677
def get_queue_itemIds(
678678
self,
679679
queue_id: str,
680-
order_by: QUEUE_ORDER_BY = "item_id",
680+
order_by: QUEUE_ORDER_BY = "completed_at",
681681
order_dir: SQLiteDirection = SQLiteDirection.Descending,
682682
) -> ItemIdsResult:
683683
with self._db.transaction() as cursor_:

invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import type { ChakraProps, CollapseProps } from '@invoke-ai/ui-library';
2-
import { Badge, ButtonGroup, Collapse, Flex, IconButton, Text } from '@invoke-ai/ui-library';
1+
import type { ChakraProps, CollapseProps, FlexProps } from '@invoke-ai/ui-library';
2+
import { Badge, ButtonGroup, Collapse, Flex, Icon, IconButton, Text } from '@invoke-ai/ui-library';
33
import QueueStatusBadge from 'features/queue/components/common/QueueStatusBadge';
44
import { useDestinationText } from 'features/queue/components/QueueList/useDestinationText';
55
import { useOriginText } from 'features/queue/components/QueueList/useOriginText';
@@ -11,7 +11,7 @@ import { selectShouldShowCredits } from 'features/system/store/configSlice';
1111
import type { MouseEvent } from 'react';
1212
import { memo, useCallback, useMemo } from 'react';
1313
import { useTranslation } from 'react-i18next';
14-
import { PiArrowCounterClockwiseBold, PiXBold } from 'react-icons/pi';
14+
import { PiArrowCounterClockwiseBold, PiImageBold, PiXBold } from 'react-icons/pi';
1515
import { useSelector } from 'react-redux';
1616
import type { S } from 'services/api/types';
1717

@@ -154,3 +154,11 @@ const transition: CollapseProps['transition'] = {
154154
};
155155

156156
export default memo(QueueItemComponent);
157+
158+
export const QueueItemPlaceholder = memo((props: FlexProps) => (
159+
<Flex w="full" h="full" bg="base.850" borderRadius="base" alignItems="center" justifyContent="center" {...props}>
160+
<Icon as={PiImageBold} boxSize={16} color="base.800" />
161+
</Flex>
162+
));
163+
164+
QueueItemPlaceholder.displayName = 'QueueItemPlaceholder';

invokeai/frontend/web/src/features/queue/components/QueueList/QueueList.tsx

Lines changed: 103 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,114 @@
1-
import { Flex, Heading } from '@invoke-ai/ui-library';
2-
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
1+
import { Flex, Heading, ListItem } from '@invoke-ai/ui-library';
32
import { IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
4-
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
5-
import {
6-
listCursorChanged,
7-
listPriorityChanged,
8-
selectQueueListCursor,
9-
selectQueueListPriority,
10-
} from 'features/queue/store/queueSlice';
11-
import { useOverlayScrollbars } from 'overlayscrollbars-react';
12-
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
3+
import { useRangeBasedQueueItemFetching } from 'features/queue/hooks/useRangeBasedQueueItemFetching';
4+
import { memo, useCallback, useMemo, useRef, useState } from 'react';
135
import { useTranslation } from 'react-i18next';
14-
import type { Components, ItemContent } from 'react-virtuoso';
6+
import type {
7+
Components,
8+
Components,
9+
Components,
10+
ComputeItemKey,
11+
ItemContent,
12+
ListRange,
13+
ScrollSeekConfiguration,
14+
VirtuosoHandle,
15+
} from 'react-virtuoso';
1516
import { Virtuoso } from 'react-virtuoso';
16-
import { queueItemsAdapterSelectors, useListQueueItemsQuery } from 'services/api/endpoints/queue';
17-
import type { S } from 'services/api/types';
17+
import { queueApi } from 'services/api/endpoints/queue';
1818

19-
import QueueItemComponent from './QueueItemComponent';
19+
import QueueItemComponent, { QueueItemPlaceholder } from './QueueItemComponent';
2020
import QueueListComponent from './QueueListComponent';
2121
import QueueListHeader from './QueueListHeader';
2222
import type { ListContext } from './types';
23+
import { useQueueItemIds } from './useQueueItemIds';
24+
import { useScrollableQueueList } from './useScrollableQueueList';
25+
26+
const QueueItemAtPosition = memo(
27+
({ index, itemId, context }: { index: number; itemId: number; context: ListContext }) => {
28+
/*
29+
* We rely on the useRangeBasedQueueItemFetching to fetch all queue items, caching them with RTK Query.
30+
*
31+
* In this component, we just want to consume that cache. Unforutnately, RTK Query does not provide a way to
32+
* subscribe to a query without triggering a new fetch.
33+
*
34+
* There is a hack, though:
35+
* - https://github.com/reduxjs/redux-toolkit/discussions/4213
36+
*
37+
* This essentially means "subscribe to the query once it has some data".
38+
*/
39+
40+
// Use `currentData` instead of `data` to prevent a flash of previous queue item rendered at this index
41+
const { currentData: queueItem, isUninitialized } = queueApi.endpoints.getQueueItem.useQueryState(itemId);
42+
queueApi.endpoints.getQueueItem.useQuerySubscription(itemId, { skip: isUninitialized });
43+
44+
if (!queueItem) {
45+
return <QueueItemPlaceholder item-id={itemId} />;
46+
}
47+
48+
return <QueueItemComponent index={index} item={queueItem} context={context} />;
49+
}
50+
);
51+
QueueItemAtPosition.displayName = 'QueueItemAtPosition';
52+
53+
const computeItemKey: ComputeItemKey<number, ListContext> = (index, itemId, { queryArgs }) => {
54+
return `${JSON.stringify(queryArgs)}-${itemId ?? index}`;
55+
};
2356

24-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
25-
type TableVirtuosoScrollerRef = (ref: HTMLElement | Window | null) => any;
57+
const itemContent: ItemContent<number, ListContext> = (index, itemId, context) => (
58+
<QueueItemAtPosition index={index} itemId={itemId} context={context} />
59+
);
60+
61+
const ScrollSeekPlaceholderComponent: Components<ListContext>['ScrollSeekPlaceholder'] = (props) => (
62+
<ListItem aspectRatio="1/1" {...props}>
63+
<QueueItemPlaceholder />
64+
</ListItem>
65+
);
2666

27-
const computeItemKey = (index: number, item: S['SessionQueueItem']): number => item.item_id;
67+
ScrollSeekPlaceholderComponent.displayName = 'ScrollSeekPlaceholderComponent';
2868

29-
const components: Components<S['SessionQueueItem'], ListContext> = {
69+
const components: Components<number, ListContext> = {
3070
List: QueueListComponent,
71+
// ScrollSeekPlaceholder: ScrollSeekPlaceholderComponent,
3172
};
3273

33-
const itemContent: ItemContent<S['SessionQueueItem'], ListContext> = (index, item, context) => (
34-
<QueueItemComponent index={index} item={item} context={context} />
35-
);
74+
const scrollSeekConfiguration: ScrollSeekConfiguration = {
75+
enter: (velocity) => {
76+
return Math.abs(velocity) > 2048;
77+
},
78+
exit: (velocity) => {
79+
return velocity === 0;
80+
},
81+
};
3682

37-
const QueueList = () => {
38-
const listCursor = useAppSelector(selectQueueListCursor);
39-
const listPriority = useAppSelector(selectQueueListPriority);
40-
const dispatch = useAppDispatch();
83+
export const QueueList = () => {
84+
const virtuosoRef = useRef<VirtuosoHandle>(null);
85+
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
4186
const rootRef = useRef<HTMLDivElement>(null);
42-
const [scroller, setScroller] = useState<HTMLElement | null>(null);
43-
const [initialize, osInstance] = useOverlayScrollbars(overlayScrollbarsParams);
4487
const { t } = useTranslation();
4588

46-
useEffect(() => {
47-
const { current: root } = rootRef;
48-
if (scroller && root) {
49-
initialize({
50-
target: root,
51-
elements: {
52-
viewport: scroller,
53-
},
54-
});
55-
}
56-
return () => osInstance()?.destroy();
57-
}, [scroller, initialize, osInstance]);
58-
59-
const { data: listQueueItemsData, isLoading } = useListQueueItemsQuery(
60-
{
61-
cursor: listCursor,
62-
priority: listPriority,
89+
// Get the ordered list of queue item ids - this is our primary data source for virtualization
90+
const { queryArgs, itemIds, isLoading } = useQueueItemIds();
91+
92+
// Use range-based fetching for bulk loading queue items into cache based on the visible range
93+
const { onRangeChanged } = useRangeBasedQueueItemFetching({
94+
itemIds,
95+
enabled: !isLoading,
96+
});
97+
98+
const scrollerRef = useScrollableQueueList(rootRef) as (ref: HTMLElement | Window | null) => void;
99+
100+
/*
101+
* We have to keep track of the visible range for keep-selected-image-in-view functionality and push the range to
102+
* the range-based queue item fetching hook.
103+
*/
104+
const handleRangeChanged = useCallback(
105+
(range: ListRange) => {
106+
rangeRef.current = range;
107+
onRangeChanged(range);
63108
},
64-
{
65-
refetchOnMountOrArgChange: true,
66-
}
109+
[onRangeChanged]
67110
);
68111

69-
const queueItems = useMemo(() => {
70-
if (!listQueueItemsData) {
71-
return [];
72-
}
73-
return queueItemsAdapterSelectors.selectAll(listQueueItemsData);
74-
}, [listQueueItemsData]);
75-
76-
const handleLoadMore = useCallback(() => {
77-
if (!listQueueItemsData?.has_more) {
78-
return;
79-
}
80-
const lastItem = queueItems[queueItems.length - 1];
81-
if (!lastItem) {
82-
return;
83-
}
84-
dispatch(listCursorChanged(lastItem.item_id));
85-
dispatch(listPriorityChanged(lastItem.priority));
86-
}, [dispatch, listQueueItemsData?.has_more, queueItems]);
87-
88112
const [openQueueItems, setOpenQueueItems] = useState<number[]>([]);
89113

90114
const toggleQueueItem = useCallback((item_id: number) => {
@@ -96,13 +120,16 @@ const QueueList = () => {
96120
});
97121
}, []);
98122

99-
const context = useMemo<ListContext>(() => ({ openQueueItems, toggleQueueItem }), [openQueueItems, toggleQueueItem]);
123+
const context = useMemo<ListContext>(
124+
() => ({ queryArgs, openQueueItems, toggleQueueItem }),
125+
[queryArgs, openQueueItems, toggleQueueItem]
126+
);
100127

101128
if (isLoading) {
102129
return <IAINoContentFallbackWithSpinner />;
103130
}
104131

105-
if (!queueItems.length) {
132+
if (!itemIds.length) {
106133
return (
107134
<Flex w="full" h="full" alignItems="center" justifyContent="center">
108135
<Heading color="base.500">{t('queue.queueEmpty')}</Heading>
@@ -114,18 +141,18 @@ const QueueList = () => {
114141
<Flex w="full" h="full" flexDir="column">
115142
<QueueListHeader />
116143
<Flex ref={rootRef} w="full" h="full" alignItems="center" justifyContent="center">
117-
<Virtuoso<S['SessionQueueItem'], ListContext>
118-
data={queueItems}
119-
endReached={handleLoadMore}
120-
scrollerRef={setScroller as TableVirtuosoScrollerRef}
144+
<Virtuoso<number, ListContext>
145+
ref={virtuosoRef}
146+
context={context}
147+
data={itemIds}
121148
itemContent={itemContent}
122149
computeItemKey={computeItemKey}
123150
components={components}
124-
context={context}
151+
scrollerRef={scrollerRef}
152+
scrollSeekConfiguration={scrollSeekConfiguration}
153+
rangeChanged={handleRangeChanged}
125154
/>
126155
</Flex>
127156
</Flex>
128157
);
129158
};
130-
131-
export default memo(QueueList);

invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,30 +33,27 @@ const QueueListHeader = () => {
3333
<QueueListHeaderColumn
3434
field="completed_at"
3535
displayName={t('queue.completedAt')}
36-
ps={0.5} w={COLUMN_WIDTHS.completedAt} alignItems="center"
37-
/>
38-
<QueueListHeaderColumn
39-
displayName={t('queue.origin')}
40-
ps={0.5} w={COLUMN_WIDTHS.origin} alignItems="center"
36+
ps={0.5}
37+
w={COLUMN_WIDTHS.completedAt}
38+
alignItems="center"
4139
/>
40+
<QueueListHeaderColumn displayName={t('queue.origin')} ps={0.5} w={COLUMN_WIDTHS.origin} alignItems="center" />
4241
<QueueListHeaderColumn
4342
displayName={t('queue.destination')}
44-
ps={0.5} w={COLUMN_WIDTHS.destination} alignItems="center"
45-
/>
46-
<QueueListHeaderColumn
47-
displayName={t('queue.time')}
48-
ps={0.5} w={COLUMN_WIDTHS.time} alignItems="center"
43+
ps={0.5}
44+
w={COLUMN_WIDTHS.destination}
45+
alignItems="center"
4946
/>
47+
<QueueListHeaderColumn displayName={t('queue.time')} ps={0.5} w={COLUMN_WIDTHS.time} alignItems="center" />
5048
{shouldShowCredits && (
5149
<QueueListHeaderColumn
5250
displayName={t('queue.credits')}
53-
ps={0.5} w={COLUMN_WIDTHS.credits} alignItems="center"
51+
ps={0.5}
52+
w={COLUMN_WIDTHS.credits}
53+
alignItems="center"
5454
/>
5555
)}
56-
<QueueListHeaderColumn
57-
displayName={t('queue.batch')}
58-
ps={0.5} w={COLUMN_WIDTHS.batchId} alignItems="center"
59-
/>
56+
<QueueListHeaderColumn displayName={t('queue.batch')} ps={0.5} w={COLUMN_WIDTHS.batchId} alignItems="center" />
6057
</Flex>
6158
);
6259
};

0 commit comments

Comments
 (0)