106 changes: 94 additions & 12 deletions invokeai/app/services/invocation_stats.py
@@ -29,6 +29,7 @@
writes to the system log is stored in InvocationServices.performance_statistics.
"""

import psutil
import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
@@ -42,6 +43,11 @@
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
from .model_manager_service import ModelManagerService
from invokeai.backend.model_management.model_cache import CacheStats

# one gigabyte, in bytes
GIG = 1073741824


class InvocationStatsServiceBase(ABC):
@@ -89,6 +95,8 @@ def update_invocation_stats(
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
"""
Add timing information on execution of a node. Usually
@@ -97,6 +105,8 @@ def update_invocation_stats(
:param invocation_type: String literal type of the node
:param time_used: Time used by node's execution (sec)
:param vram_used: Maximum VRAM used during execution (GB)
:param ram_used: Current RAM used by the process (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
pass

@@ -115,6 +125,9 @@ class NodeStats:
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0


@dataclass
@@ -133,31 +146,62 @@ def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"
self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
self._cache_stats: Dict[str, CacheStats] = {}
self.ram_used: float = 0.0
self.ram_changed: float = 0.0

class StatsContext:
-def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
"""Context manager for collecting statistics."""

invocation: BaseInvocation = None
collector: "InvocationStatsServiceBase" = None
graph_id: str = None
start_time: int = 0
ram_used: int = 0
model_manager: ModelManagerService = None

def __init__(
self,
invocation: BaseInvocation,
graph_id: str,
model_manager: ModelManagerService,
collector: "InvocationStatsServiceBase",
):
"""Initialize statistics for this run."""
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
self.ram_used = 0
self.model_manager = model_manager

def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.ram_used = psutil.Process().memory_info().rss
if self.model_manager:
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])

def __exit__(self, *args):
"""Called on exit from the context."""
ram_used = psutil.Process().memory_info().rss
self.collector.update_mem_stats(
ram_used=ram_used / GIG,
ram_changed=(ram_used - self.ram_used) / GIG,
)
self.collector.update_invocation_stats(
-self.graph_id,
-self.invocation.type,
-time.time() - self.start_time,
-torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
graph_id=self.graph_id,
invocation_type=self.invocation.type,
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
)

def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
@@ -166,7 +210,8 @@ def collect_stats(
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
-return self.StatsContext(invocation, graph_execution_state_id, self)
self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)

def reset_all_stats(self):
"""Zero all statistics"""
@@ -179,13 +224,36 @@ def reset_stats(self, graph_execution_id: str):
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")

-def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.

:param ram_used: How much RAM is currently in use (GB).
:param ram_changed: How much RAM changed since the last generation (GB).
"""
self.ram_used = ram_used
self.ram_changed = ram_changed

def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
-:param time_used: Floating point seconds used by node's exection
:param time_used: Time used by node's execution (sec)
:param vram_used: Maximum VRAM used during execution (GB)
"""
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
@@ -197,7 +265,7 @@ def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
-Stats will only be printed if when the execution of the graph
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set()
@@ -208,16 +276,30 @@ def log_stats(self):

total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info("Node Calls Seconds VRAM Used")
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
total_time += stats.time_used

cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG

logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
logger.info(f"RAM used to load models: {loaded:4.2f}G")
if torch.cuda.is_available():
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
logger.info("RAM cache statistics:")
logger.info(f" Model cache hits: {cache_stats.hits}")
logger.info(f" Model cache misses: {cache_stats.misses}")
logger.info(f" Models cached: {cache_stats.in_cache}")
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")

completed.add(graph_id)

for graph_id in completed:
del self._stats[graph_id]
del self._cache_stats[graph_id]
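Taken together, the `invocation_stats.py` changes make the collector track process RAM and model-cache behavior alongside each node's time and VRAM. A minimal usage sketch (hypothetical `statistics`, `invocation`, `graph_id`, `model_manager`, and `context` objects, mirroring the `processor.py` change below):

```python
# Sketch only: the names here are assumptions, not part of this diff.
with statistics.collect_stats(invocation, graph_id, model_manager):
    # __enter__ resets the peak-VRAM counter, snapshots process RSS, and
    # hands the graph's CacheStats object to the model manager.
    outputs = invocation.invoke_internal(context)
    # __exit__ records elapsed time, peak VRAM, and the RAM delta.

statistics.log_stats()  # prints the report once the whole graph is complete
```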
14 changes: 14 additions & 0 deletions invokeai/app/services/model_manager_service.py
@@ -22,6 +22,7 @@
ModelNotFoundException,
)
from invokeai.backend.model_management.model_search import FindModels
from invokeai.backend.model_management.model_cache import CacheStats

import torch
from invokeai.app.models.exceptions import CanceledException
@@ -276,6 +277,13 @@ def sync_to_config(self):
"""
pass

@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics, directing them into the given CacheStats object.
"""
pass

@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
@@ -500,6 +508,12 @@ def convert_model(
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)

def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics, directing them into the given CacheStats object.
"""
self.mgr.cache.stats = cache_stats

def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.
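The new `collect_cache_stats()` is the glue between the two services: the stats collector owns one `CacheStats` object per graph, and the model manager simply points the model cache at it (`self.mgr.cache.stats = cache_stats`), so cache hits and misses land in the same object the collector later reports. A rough sketch of that flow, with hypothetical instances:

```python
from invokeai.backend.model_management.model_cache import CacheStats

stats = CacheStats()                      # held by the collector, keyed by graph id
model_manager.collect_cache_stats(stats)  # the cache now records into this object

# ... graph executes; ModelCache.get_model() increments hits/misses ...

print(stats.hits, stats.misses, stats.cleared)  # read back at report time
```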
4 changes: 3 additions & 1 deletion invokeai/app/services/processor.py
@@ -86,7 +86,9 @@ def __process(self, stop_event: Event):

# Invoke
try:
-with statistics.collect_stats(invocation, graph_execution_state.id):
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
# use the internal invoke_internal(), which wraps the node's invoke() method in
# this accommodates nodes which require a value, but get it only from a
# connection
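Note that because `collect_stats()` returns a context manager, its `__exit__` runs even if the wrapped node raises, so a failing node's elapsed time and memory delta are still folded into the graph's statistics before the surrounding `try` handles the error.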
44 changes: 38 additions & 6 deletions invokeai/backend/model_management/model_cache.py
@@ -21,12 +21,12 @@
import sys
import hashlib
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Union, types, Optional, Type, Any

import torch

-import logging
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, SubModelType, ModelBase

@@ -41,6 +41,18 @@
GIG = 1073741824


@dataclass
class CacheStats(object):
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # peak amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)


class ModelLocker(object):
"Forward declaration"
pass
@@ -115,6 +127,9 @@ def __init__(
self.sha_chunksize = sha_chunksize
self.logger = logger

# used for stats collection
self.stats = None

self._cached_models = dict()
self._cache_stack = list()

@@ -181,13 +196,14 @@ def get_model(
model_type=model_type,
submodel_type=submodel,
)

# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1

# this will remove older cached models until
# there is sufficient room to load the requested model
@@ -201,6 +217,17 @@

cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
else:
if self.stats:
self.stats.hits += 1

if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
)

with suppress(Exception):
self._cache_stack.remove(key)
@@ -280,14 +307,14 @@ def model_hash(
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs

:param model_path: Path to model file/directory on disk.
"""
return self._local_model_hash(model_path)

def cache_size(self) -> float:
"Return the current size of the cache, in GB"
current_cache_size = sum([m.size for m in self._cached_models.values()])
return current_cache_size / GIG
"""Return the current size of the cache, in GB."""
return self._cache_size() / GIG

def _has_cuda(self) -> bool:
return self.execution_device.type == "cuda"
@@ -310,12 +337,15 @@ def _print_cuda_stats(self):
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
)

def _cache_size(self) -> int:
return sum([m.size for m in self._cached_models.values()])

def _make_cache_room(self, model_size):
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
-current_size = sum([m.size for m in self._cached_models.values()])
current_size = self._cache_size()

if current_size + bytes_needed > maximum_size:
self.logger.debug(
Expand Down Expand Up @@ -364,6 +394,8 @@ def _make_cache_room(self, model_size):
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry
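For reference, `CacheStats` keeps every counter in bytes and converts with `GIG` only at report time. A small sketch (assumed helper, not part of the diff) showing how the raw counters map onto the figures `log_stats()` prints:

```python
from invokeai.backend.model_management.model_cache import CacheStats

GIG = 1073741824  # one gigabyte, in bytes

def summarize(stats: CacheStats) -> str:
    """Render roughly the same figures log_stats() reports."""
    hwm = stats.high_watermark / GIG                       # peak bytes held in the cache
    tot = stats.cache_size / GIG                           # configured cache ceiling
    loaded = sum(stats.loaded_model_sizes.values()) / GIG  # per-model sizes, summed
    return (
        f"hits={stats.hits} misses={stats.misses} cleared={stats.cleared} "
        f"loaded={loaded:4.2f}G high water={hwm:4.2f}/{tot:4.2f}G"
    )
```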