
Commit 9091e19

Add execution stat reporting after each invocation (#4125)
## What type of PR is this? (check all applicable)

- [X] Feature

## Have you discussed this change with the InvokeAI team?

- [X] Yes
- [ ] No, because:

## Have you updated all relevant documentation?

- [X] Yes
- [ ] No

## Description

This PR adds execution time and VRAM usage reporting to each graph invocation. The log output will look like this:

```
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node                 Calls  Seconds VRAM Used
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader        1   0.005s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip                1   0.004s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel                   2   0.512s 0.26G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int                 1   0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size            1   0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate                  1   0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator     1   0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise                    1   0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l                      1   3.541s 1.93G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i                      1   0.679s 0.58G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME:   4.749s
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
```

On systems without CUDA, the VRAM stats are not printed. The current implementation tracks graph IDs separately, so it will not be confused when several graphs are executing in parallel. It handles exceptions, and it is integrated into the app framework by defining an abstract base class and storing an implementation instance in `InvocationServices`.
2 parents: bf94412 + 0a0b714
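As background for reviewers, the per-node numbers come from wall-clock timing plus torch's CUDA allocator counters. Below is a minimal sketch of that measurement pattern, assuming a CUDA-capable machine; `run_node` is a placeholder for executing one invocation, not a name from this PR:

```python
import time

import torch


def run_node() -> None:
    """Placeholder for executing a single graph invocation."""


start = time.time()
if torch.cuda.is_available():
    # Zero the allocator's peak counter so the peak read below is
    # attributable to this node alone.
    torch.cuda.reset_peak_memory_stats()

run_node()

elapsed = time.time() - start  # wall-clock seconds spent in the node
# Peak bytes allocated while the node ran, reported in GB (0.0 without CUDA).
vram_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
print(f"{elapsed:7.3f}s {vram_gb:4.2f}G")
```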

File tree

8 files changed: +276 −33 lines changed

invokeai/app/api/dependencies.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -2,7 +2,6 @@
 
 from typing import Optional
 from logging import Logger
-import os
 from invokeai.app.services.board_image_record_storage import (
     SqliteBoardImageRecordStorage,
 )
@@ -30,6 +29,7 @@
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
 from ..services.model_manager_service import ModelManagerService
+from ..services.invocation_stats import InvocationStatsService
 from .events import FastAPIEventService
 
 
@@ -128,6 +128,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
         configuration=config,
+        performance_statistics=InvocationStatsService(graph_execution_manager),
         logger=logger,
     )
```
invokeai/app/cli_app.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -37,6 +37,7 @@
 from invokeai.app.services.images import ImageService, ImageServiceDependencies
 from invokeai.app.services.resource_name import SimpleNameService
 from invokeai.app.services.urls import LocalUrlService
+from invokeai.app.services.invocation_stats import InvocationStatsService
 from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
 
@@ -311,6 +312,7 @@ def invoke_cli():
         graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
+        performance_statistics=InvocationStatsService(graph_execution_manager),
         logger=logger,
         configuration=config,
     )
```

invokeai/app/services/invocation_services.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -32,6 +32,7 @@ class InvocationServices:
     logger: "Logger"
     model_manager: "ModelManagerServiceBase"
     processor: "InvocationProcessorABC"
+    performance_statistics: "InvocationStatsServiceBase"
     queue: "InvocationQueueABC"
 
     def __init__(
@@ -47,6 +48,7 @@ def __init__(
         logger: "Logger",
         model_manager: "ModelManagerServiceBase",
         processor: "InvocationProcessorABC",
+        performance_statistics: "InvocationStatsServiceBase",
         queue: "InvocationQueueABC",
     ):
         self.board_images = board_images
@@ -61,4 +63,5 @@ def __init__(
         self.logger = logger
         self.model_manager = model_manager
         self.processor = processor
+        self.performance_statistics = performance_statistics
         self.queue = queue
```
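Since `performance_statistics` now rides along on `InvocationServices`, anything holding the services object — the processor, or a node through its `InvocationContext` — can reach the collector. A hypothetical sketch of that access path (`log_if_done` is an illustration, not part of the PR):

```python
from invokeai.app.invocations.baseinvocation import InvocationContext


def log_if_done(context: InvocationContext) -> None:
    """Hypothetical helper: flush stats for any graphs that have finished."""
    stats = context.services.performance_statistics
    stats.log_stats()  # logs and clears only graphs whose state is complete
```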
invokeai/app/services/invocation_stats.py

Lines changed: 223 additions & 0 deletions

```python
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
"""Utility to collect execution time and GPU usage stats on invocations in flight"""

"""
Usage:

statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state.id):
    ... execute graphs...
statistics.log_stats()

Typical output:
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node                 Calls  Seconds VRAM Used
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader        1   0.005s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip                1   0.004s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel                   2   0.512s 0.26G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int                 1   0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size            1   0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate                  1   0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator     1   0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise                    1   0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l                      1   3.541s 1.93G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i                      1   0.679s 0.58G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME:   4.749s
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G

The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
writes to the system log is stored in InvocationServices.performance_statistics.
"""

import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from typing import Dict

import torch

import invokeai.backend.util.logging as logger

from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC


class InvocationStatsServiceBase(ABC):
    "Abstract base class for recording node memory/time performance statistics"

    @abstractmethod
    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
        """
        Initialize the InvocationStatsService and reset counters to zero
        :param graph_execution_manager: Graph execution manager for this session
        """
        pass

    @abstractmethod
    def collect_stats(
        self,
        invocation: BaseInvocation,
        graph_execution_state_id: str,
    ) -> AbstractContextManager:
        """
        Return a context object that will capture the statistics on the execution
        of the invocation. Use "with" to place it around the part of the code that executes the invocation.
        :param invocation: BaseInvocation object from the current graph.
        :param graph_execution_state_id: ID of the GraphExecutionState for the current session.
        """
        pass

    @abstractmethod
    def reset_stats(self, graph_execution_state_id: str):
        """
        Reset all statistics for the indicated graph
        :param graph_execution_state_id: ID of the graph whose statistics are to be reset
        """
        pass

    @abstractmethod
    def reset_all_stats(self):
        """Zero all statistics"""
        pass

    @abstractmethod
    def update_invocation_stats(
        self,
        graph_id: str,
        invocation_type: str,
        time_used: float,
        vram_used: float,
    ):
        """
        Add timing information on execution of a node. Usually
        used internally.
        :param graph_id: ID of the graph that is currently executing
        :param invocation_type: String literal type of the node
        :param time_used: Time used by node's execution (sec)
        :param vram_used: Maximum VRAM used during execution (GB)
        """
        pass

    @abstractmethod
    def log_stats(self):
        """
        Write out the accumulated statistics to the log or somewhere else.
        """
        pass


@dataclass
class NodeStats:
    """Class for tracking execution stats of an invocation node"""

    calls: int = 0
    time_used: float = 0.0  # seconds
    max_vram: float = 0.0  # GB


@dataclass
class NodeLog:
    """Class for tracking node usage"""

    # {node_type => NodeStats}
    nodes: Dict[str, NodeStats] = field(default_factory=dict)


class InvocationStatsService(InvocationStatsServiceBase):
    """Accumulate performance information about a running graph. Collects time spent in each node,
    as well as the maximum and current VRAM utilisation for CUDA systems"""

    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
        self.graph_execution_manager = graph_execution_manager
        # {graph_id => NodeLog}
        self._stats: Dict[str, NodeLog] = {}

    class StatsContext:
        def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
            self.invocation = invocation
            self.collector = collector
            self.graph_id = graph_id
            self.start_time = 0

        def __enter__(self):
            self.start_time = time.time()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

        def __exit__(self, *args):
            self.collector.update_invocation_stats(
                self.graph_id,
                self.invocation.type,
                time.time() - self.start_time,
                torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
            )

    def collect_stats(
        self,
        invocation: BaseInvocation,
        graph_execution_state_id: str,
    ) -> StatsContext:
        """
        Return a context object that will capture the statistics.
        :param invocation: BaseInvocation object from the current graph.
        :param graph_execution_state_id: ID of the GraphExecutionState for the current session.
        """
        if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
            self._stats[graph_execution_state_id] = NodeLog()
        return self.StatsContext(invocation, graph_execution_state_id, self)

    def reset_all_stats(self):
        """Zero all statistics"""
        self._stats = {}

    def reset_stats(self, graph_execution_id: str):
        """Zero the statistics for the indicated graph."""
        try:
            self._stats.pop(graph_execution_id)
        except KeyError:
            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")

    def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
        """
        Add timing information on execution of a node. Usually
        used internally.
        :param graph_id: ID of the graph that is currently executing
        :param invocation_type: String literal type of the node
        :param time_used: Floating point seconds used by node's execution
        :param vram_used: Maximum VRAM used during execution (GB)
        """
        if not self._stats[graph_id].nodes.get(invocation_type):
            self._stats[graph_id].nodes[invocation_type] = NodeStats()
        stats = self._stats[graph_id].nodes[invocation_type]
        stats.calls += 1
        stats.time_used += time_used
        stats.max_vram = max(stats.max_vram, vram_used)

    def log_stats(self):
        """
        Send the statistics to the system logger at the info level.
        Stats will only be printed when the execution of the graph
        is complete.
        """
        completed = set()
        for graph_id, node_log in self._stats.items():
            current_graph_state = self.graph_execution_manager.get(graph_id)
            if not current_graph_state.is_complete():
                continue

            total_time = 0
            logger.info(f"Graph stats: {graph_id}")
            logger.info("Node                 Calls  Seconds VRAM Used")
            for node_type, stats in self._stats[graph_id].nodes.items():
                logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
                total_time += stats.time_used

            logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
            if torch.cuda.is_available():
                logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))

            completed.add(graph_id)

        for graph_id in completed:
            del self._stats[graph_id]
```
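A quick way to exercise the service outside the app is to drive it by hand. The harness below is a hypothetical sketch: `FakeManager` and `FakeState` stand in for the real graph execution manager (an `ItemStorageABC`), and a `SimpleNamespace` fakes the `.type` attribute that `StatsContext.__exit__` reads from a `BaseInvocation`:

```python
from types import SimpleNamespace

from invokeai.app.services.invocation_stats import InvocationStatsService


class FakeState:
    """Stand-in for GraphExecutionState; always reports completion."""

    def is_complete(self) -> bool:
        return True


class FakeManager:
    """Stand-in for the graph execution manager consulted by log_stats()."""

    def get(self, graph_id: str) -> FakeState:
        return FakeState()


stats = InvocationStatsService(FakeManager())
fake_invocation = SimpleNamespace(type="t2l")  # only .type is consulted

with stats.collect_stats(fake_invocation, "graph-1"):
    pass  # the node body would execute here

stats.log_stats()  # prints the per-node table, then forgets "graph-1"
```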

invokeai/app/services/processor.py

Lines changed: 30 additions & 24 deletions

```diff
@@ -1,13 +1,14 @@
 import time
 import traceback
-from threading import Event, Thread, BoundedSemaphore
+from threading import BoundedSemaphore, Event, Thread
+
+import invokeai.backend.util.logging as logger
 
 from ..invocations.baseinvocation import InvocationContext
+from ..models.exceptions import CanceledException
 from .invocation_queue import InvocationQueueItem
+from .invocation_stats import InvocationStatsServiceBase
 from .invoker import InvocationProcessorABC, Invoker
-from ..models.exceptions import CanceledException
-
-import invokeai.backend.util.logging as logger
 
 
 class DefaultInvocationProcessor(InvocationProcessorABC):
@@ -35,6 +36,8 @@ def stop(self, *args, **kwargs) -> None:
     def __process(self, stop_event: Event):
         try:
             self.__threadLimit.acquire()
+            statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
+
             while not stop_event.is_set():
                 try:
                     queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
@@ -83,35 +86,38 @@ def __process(self, stop_event: Event):
 
                     # Invoke
                     try:
-                        outputs = invocation.invoke(
-                            InvocationContext(
-                                services=self.__invoker.services,
-                                graph_execution_state_id=graph_execution_state.id,
+                        with statistics.collect_stats(invocation, graph_execution_state.id):
+                            outputs = invocation.invoke(
+                                InvocationContext(
+                                    services=self.__invoker.services,
+                                    graph_execution_state_id=graph_execution_state.id,
+                                )
                             )
-                        )
 
-                        # Check queue to see if this is canceled, and skip if so
-                        if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
-                            continue
+                            # Check queue to see if this is canceled, and skip if so
+                            if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
+                                continue
 
-                        # Save outputs and history
-                        graph_execution_state.complete(invocation.id, outputs)
+                            # Save outputs and history
+                            graph_execution_state.complete(invocation.id, outputs)
 
-                        # Save the state changes
-                        self.__invoker.services.graph_execution_manager.set(graph_execution_state)
+                            # Save the state changes
+                            self.__invoker.services.graph_execution_manager.set(graph_execution_state)
 
-                        # Send complete event
-                        self.__invoker.services.events.emit_invocation_complete(
-                            graph_execution_state_id=graph_execution_state.id,
-                            node=invocation.dict(),
-                            source_node_id=source_node_id,
-                            result=outputs.dict(),
-                        )
+                            # Send complete event
+                            self.__invoker.services.events.emit_invocation_complete(
+                                graph_execution_state_id=graph_execution_state.id,
+                                node=invocation.dict(),
+                                source_node_id=source_node_id,
+                                result=outputs.dict(),
+                            )
+                            statistics.log_stats()
 
                     except KeyboardInterrupt:
                         pass
 
                     except CanceledException:
+                        statistics.reset_stats(graph_execution_state.id)
                         pass
 
                     except Exception as e:
@@ -133,7 +139,7 @@ def __process(self, stop_event: Event):
                             error_type=e.__class__.__name__,
                             error=error,
                         )
-
+                        statistics.reset_stats(graph_execution_state.id)
                         pass
 
                     # Check queue to see if this is canceled, and skip if so
```
scripts/dream.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -2,10 +2,12 @@
 # Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
 
 import warnings
-from invokeai.frontend.CLI import invokeai_command_line_interface as main
 
 warnings.warn(
     "dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
     DeprecationWarning,
 )
-main()
+
+from invokeai.app.cli_app import invoke_cli
+
+invoke_cli()
```
