Skip to content

Commit cf61bf7

Browse files
authored
feat(llm): add progress bar when ollama is pulling models (#2031)
* fix: add ollama progress bar when pulling models * feat: add ollama queue * fix: mypy
1 parent 50b3027 commit cf61bf7

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

private_gpt/utils/ollama.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
import logging
2+
from collections import deque
3+
from collections.abc import Iterator, Mapping
4+
from typing import Any
5+
6+
from tqdm import tqdm # type: ignore
27

38
try:
49
from ollama import Client # type: ignore
@@ -19,12 +24,55 @@ def check_connection(client: Client) -> bool:
1924
return False
2025

2126

27+
def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
28+
progress_bars = {}
29+
queue = deque() # type: ignore
30+
31+
def create_progress_bar(dgt: str, total: int) -> Any:
32+
return tqdm(
33+
total=total, desc=f"Pulling model {dgt[7:17]}...", unit="B", unit_scale=True
34+
)
35+
36+
current_digest = None
37+
38+
for chunk in generator:
39+
digest = chunk.get("digest")
40+
completed_size = chunk.get("completed", 0)
41+
total_size = chunk.get("total")
42+
43+
if digest and total_size is not None:
44+
if digest not in progress_bars and completed_size > 0:
45+
progress_bars[digest] = create_progress_bar(digest, total=total_size)
46+
if current_digest is None:
47+
current_digest = digest
48+
else:
49+
queue.append(digest)
50+
51+
if digest in progress_bars:
52+
progress_bar = progress_bars[digest]
53+
progress = completed_size - progress_bar.n
54+
if completed_size > 0 and total_size >= progress != progress_bar.n:
55+
if digest == current_digest:
56+
progress_bar.update(progress)
57+
if progress_bar.n >= total_size:
58+
progress_bar.close()
59+
current_digest = queue.popleft() if queue else None
60+
else:
61+
# Store progress for later update
62+
progress_bars[digest].total = total_size
63+
progress_bars[digest].n = completed_size
64+
65+
# Close any remaining progress bars at the end
66+
for progress_bar in progress_bars.values():
67+
progress_bar.close()
68+
69+
2270
def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
2371
try:
2472
installed_models = [model["name"] for model in client.list().get("models", {})]
2573
if model_name not in installed_models:
2674
logger.info(f"Pulling model {model_name}. Please wait...")
27-
client.pull(model_name)
75+
process_streaming(client.pull(model_name, stream=True))
2876
logger.info(f"Model {model_name} pulled successfully")
2977
except Exception as e:
3078
logger.error(f"Failed to pull model {model_name}: {e!s}")

0 commit comments

Comments
 (0)