1
1
import logging
2
+ from collections import deque
3
+ from collections .abc import Iterator , Mapping
4
+ from typing import Any
5
+
6
+ from tqdm import tqdm # type: ignore
2
7
3
8
try :
4
9
from ollama import Client # type: ignore
@@ -19,12 +24,55 @@ def check_connection(client: Client) -> bool:
19
24
return False
20
25
21
26
27
+ def process_streaming (generator : Iterator [Mapping [str , Any ]]) -> None :
28
+ progress_bars = {}
29
+ queue = deque () # type: ignore
30
+
31
+ def create_progress_bar (dgt : str , total : int ) -> Any :
32
+ return tqdm (
33
+ total = total , desc = f"Pulling model { dgt [7 :17 ]} ..." , unit = "B" , unit_scale = True
34
+ )
35
+
36
+ current_digest = None
37
+
38
+ for chunk in generator :
39
+ digest = chunk .get ("digest" )
40
+ completed_size = chunk .get ("completed" , 0 )
41
+ total_size = chunk .get ("total" )
42
+
43
+ if digest and total_size is not None :
44
+ if digest not in progress_bars and completed_size > 0 :
45
+ progress_bars [digest ] = create_progress_bar (digest , total = total_size )
46
+ if current_digest is None :
47
+ current_digest = digest
48
+ else :
49
+ queue .append (digest )
50
+
51
+ if digest in progress_bars :
52
+ progress_bar = progress_bars [digest ]
53
+ progress = completed_size - progress_bar .n
54
+ if completed_size > 0 and total_size >= progress != progress_bar .n :
55
+ if digest == current_digest :
56
+ progress_bar .update (progress )
57
+ if progress_bar .n >= total_size :
58
+ progress_bar .close ()
59
+ current_digest = queue .popleft () if queue else None
60
+ else :
61
+ # Store progress for later update
62
+ progress_bars [digest ].total = total_size
63
+ progress_bars [digest ].n = completed_size
64
+
65
+ # Close any remaining progress bars at the end
66
+ for progress_bar in progress_bars .values ():
67
+ progress_bar .close ()
68
+
69
+
22
70
def pull_model (client : Client , model_name : str , raise_error : bool = True ) -> None :
23
71
try :
24
72
installed_models = [model ["name" ] for model in client .list ().get ("models" , {})]
25
73
if model_name not in installed_models :
26
74
logger .info (f"Pulling model { model_name } . Please wait..." )
27
- client .pull (model_name )
75
+ process_streaming ( client .pull (model_name , stream = True ) )
28
76
logger .info (f"Model { model_name } pulled successfully" )
29
77
except Exception as e :
30
78
logger .error (f"Failed to pull model { model_name } : { e !s} " )
0 commit comments