Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions batchata/core/batch_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def __init__(self, config: BatchParams, jobs: List[Job]):
# Threading primitives
self._state_lock = threading.Lock()
self._shutdown_event = threading.Event()
self._progress_lock = threading.Lock()
self._last_progress_update = 0.0

# Batch tracking for progress display
self.batch_tracking: Dict[str, Dict] = {} # batch_id -> batch_info
Expand Down Expand Up @@ -231,9 +233,12 @@ def signal_handler(signum, frame):

# Call initial progress
if self._progress_callback:
stats = self.status()
batch_data = dict(self.batch_tracking)
self._progress_callback(stats, 0.0, batch_data)
with self._progress_lock:
with self._state_lock:
stats = self.status()
batch_data = dict(self.batch_tracking)
self._progress_callback(stats, 0.0, batch_data)
self._last_progress_update = time.time()

# Process all jobs synchronously
self._process_all_jobs()
Expand Down Expand Up @@ -495,14 +500,22 @@ def _poll_batch_status(self, provider, batch_id: str) -> Tuple[str, Optional[Dic
status, error_details = provider.get_batch_status(batch_id)

if self._progress_callback:
with self._state_lock:
stats = self.status()
elapsed_time = (datetime.now() - self._start_time).total_seconds()
batch_data = dict(self.batch_tracking)
self._progress_callback(stats, elapsed_time, batch_data)
# Rate limit progress updates and synchronize calls to prevent duplicate printing
current_time = time.time()
should_update = current_time - self._last_progress_update >= self._progress_interval

if should_update:
with self._progress_lock:
# Double-check timing inside the lock to avoid race condition
if current_time - self._last_progress_update >= self._progress_interval:
with self._state_lock:
stats = self.status()
elapsed_time = (datetime.now() - self._start_time).total_seconds()
batch_data = dict(self.batch_tracking)
self._progress_callback(stats, elapsed_time, batch_data)
self._last_progress_update = current_time

elapsed_seconds = poll_count * provider_polling_interval
logger.info(f"Batch {batch_id} status: {status} (polling for {elapsed_seconds:.1f}s)")

return status, error_details

Expand Down
53 changes: 30 additions & 23 deletions batchata/providers/openai/openai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,29 +272,36 @@ def estimate_cost(self, jobs: List[Job]) -> float:

for job in jobs:
try:
# Prepare messages to get actual input
from .message_prepare import prepare_messages
messages, response_format = prepare_messages(job)

# Build full text for token estimation
full_text = ""
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if isinstance(content, list):
# Handle multipart content (images, etc.)
for part in content:
if part.get("type") == "text":
full_text += f"{role}: {part.get('text', '')}\\n\\n"
else:
full_text += f"{role}: {content}\\n\\n"

# Add response format to token count if structured output
if response_format:
full_text += json.dumps(response_format)

# Estimate tokens
input_tokens = token_count_simple(full_text)
# Handle PDF files specially with accurate token estimation
if job.file and job.file.suffix.lower() == '.pdf':
from ...utils.pdf import estimate_pdf_tokens
# OpenAI: 300-1,280 tokens/page, use 1000 as reasonable average
input_tokens = estimate_pdf_tokens(job.file, job.prompt, tokens_per_page=1000)
logger.debug(f"Job {job.id}: Estimated {input_tokens} tokens for PDF")
else:
# Prepare messages to get actual input
from .message_prepare import prepare_messages
messages, response_format = prepare_messages(job)

# Build full text for token estimation
full_text = ""
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if isinstance(content, list):
# Handle multipart content (images, etc.)
for part in content:
if part.get("type") == "text":
full_text += f"{role}: {part.get('text', '')}\\n\\n"
else:
full_text += f"{role}: {content}\\n\\n"

# Add response format to token count if structured output
if response_format:
full_text += json.dumps(response_format)

# Estimate tokens
input_tokens = token_count_simple(full_text)

# Calculate costs using tokencost
input_cost = float(calculate_cost_by_tokens(
Expand Down
59 changes: 33 additions & 26 deletions batchata/utils/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,39 +217,46 @@ def estimate_pdf_tokens(path: str | Path, prompt: Optional[str] = None,
"""
Estimate token count for a PDF file.

This is a generic utility that can be used by any provider to estimate
tokens for PDF processing.
Provider-specific tokens per page estimates:
- Anthropic: 1,500-3,000 tokens/page (default: 2000)
- Gemini: ~258 tokens/page
- OpenAI: 300-1,280 tokens/page (use: 1000)

Args:
path: Path to the PDF file
prompt: Optional prompt to include in token count
pdf_token_multiplier: Coefficient to apply to extracted text tokens
to account for PDF processing overhead (default: 1.5)
tokens_per_page: Estimated tokens per page for image-based PDFs (default: 2000)
pdf_token_multiplier: Deprecated, kept for compatibility
tokens_per_page: Tokens per page estimate (default: 2000 for Anthropic)

Returns:
Estimated token count
"""
from .llm import token_count_simple

page_count, is_textual, extracted_text = get_pdf_info(path)

if is_textual and extracted_text:
# Count tokens from extracted text
base_tokens = token_count_simple(extracted_text)
if prompt:
base_tokens += token_count_simple(prompt)

# Apply multiplier to account for PDF processing overhead
input_tokens = int(base_tokens * pdf_token_multiplier)
logger.debug(f"Textual PDF {path}: {page_count} pages, "
f"base tokens: {base_tokens}, with {pdf_token_multiplier}x multiplier: {input_tokens}")
else:
# Estimate based on page count
input_tokens = page_count * tokens_per_page
if prompt:
input_tokens += token_count_simple(prompt)
logger.debug(f"Image-based PDF {path}: {page_count} pages, "
f"estimated tokens: {input_tokens} ({tokens_per_page} per page)")

return input_tokens
try:
# Get page count
reader = pypdf.PdfReader(str(path))
page_count = len(reader.pages)

# Use provider-specific tokens per page estimate
pdf_tokens = page_count * tokens_per_page

# Add prompt tokens
prompt_tokens = token_count_simple(prompt) if prompt else 0

# Add minimal overhead for PDF processing
PDF_TOKEN_OVERHEAD = 100 # tokens
overhead_tokens = PDF_TOKEN_OVERHEAD

total_tokens = pdf_tokens + prompt_tokens + overhead_tokens

logger.debug(
f"PDF {path}: {page_count} pages × {tokens_per_page} = {pdf_tokens} tokens, "
f"prompt: {prompt_tokens}, total: {total_tokens}"
)

return total_tokens

except Exception as e:
logger.warning(f"Failed to estimate PDF tokens: {e}")
return 0
138 changes: 102 additions & 36 deletions batchata/utils/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from rich.tree import Tree
from rich.text import Text

# Constants
PROGRESS_BAR_WIDTH = 25


class RichBatchProgressDisplay:
"""Rich-based progress display for batch runs."""
Expand Down Expand Up @@ -48,7 +51,7 @@ def start(self, stats: Dict, config: Dict):
self._create_display(),
console=self.console,
refresh_per_second=4, # Reduced refresh rate to avoid flicker
auto_refresh=True
auto_refresh=False # Disable auto-refresh to prevent race conditions with manual updates
)
self.live.start()

Expand All @@ -68,9 +71,11 @@ def update(self, stats: Dict, batch_data: Dict, elapsed_time: float):
# Advance spinner
self._spinner_index = (self._spinner_index + 1) % len(self._spinner_frames)

# Update live display
# Update live display (synchronized to prevent race conditions)
if self.live:
self.live.update(self._create_display())
# Force refresh since auto_refresh is disabled
self.live.refresh()

def stop(self):
"""Stop the live progress display."""
Expand Down Expand Up @@ -138,37 +143,17 @@ def _create_display(self) -> Tree:
is_last = idx == num_batches - 1
tree_symbol = "└─" if is_last else "├─"

# Format progress bar with better styling
progress_pct = (completed / total) if total > 0 else 0
filled_width = int(progress_pct * 25)
# Extract job counts
failed_count = batch_info.get('failed', 0)
success_count = completed
total_processed = success_count + failed_count
progress_pct = (total_processed / total) if total > 0 else 0

if status == 'complete':
bar = "[bold green]" + "━" * 25 + "[/bold green]"
elif status == 'failed':
bar = "[bold red]" + "━" * 25 + "[/bold red]"
elif status == 'cancelled':
bar = "[bold yellow]" + "━" * filled_width + "[/bold yellow]"
if filled_width < 25:
bar += "[dim yellow]" + "━" * (25 - filled_width) + "[/dim yellow]"
elif status == 'running':
bar = "[bold blue]" + "━" * filled_width + "[/bold blue]"
if filled_width < 25:
bar += "[blue]╸[/blue]" + "[dim white]" + "━" * (24 - filled_width) + "[/dim white]"
else:
bar = "[dim white]" + "━" * 25 + "[/dim white]"
# Create progress bar based on status
bar = self._create_progress_bar(status, success_count, failed_count, total, progress_pct)

# Format status with better colors and fixed width
if status == 'complete':
status_text = "[bold green]Ended[/bold green] "
elif status == 'failed':
status_text = "[bold red]Failed[/bold red] "
elif status == 'cancelled':
status_text = "[bold yellow]Cancelled[/bold yellow]"
elif status == 'running':
spinner = self._spinner_frames[self._spinner_index]
status_text = f"[bold blue]{spinner} Running[/bold blue]"
else:
status_text = "[dim]Pending[/dim]"
# Format status text
status_text = self._format_status_text(status, failed_count)

# Calculate elapsed time
start_time = batch_info.get('start_time')
Expand Down Expand Up @@ -196,7 +181,7 @@ def _create_display(self) -> Tree:
else:
time_str = "-:--:--"

# Format percentage
# Format percentage based on total processed (successful + failed)
percentage = int(progress_pct * 100)

# Get output filenames if completed
Expand Down Expand Up @@ -226,11 +211,11 @@ def _create_display(self) -> Tree:
else:
cost_text = f"${cost:>5.3f}"

# Create the batch line with proper spacing
# Create the batch line
display_stats = self._get_display_stats(status, success_count, failed_count, total)
batch_line = (
f"{provider} {batch_id:<18} {bar} "
f"{completed:>2}/{total:<2} {percentage:>3}% "
f"{status_text} "
f"{display_stats['completed']:>2}/{total:<2} ({display_stats['percentage']}% done) {status_text:<15} "
f"{cost_text} "
f"{time_str:>8}"
)
Expand Down Expand Up @@ -275,4 +260,85 @@ def _create_display(self) -> Tree:
footer = " │ ".join(footer_parts)
tree.add(f"\n[dim]{footer}[/dim]")

return tree
return tree

def _create_progress_bar(self, status: str, success_count: int, failed_count: int, total: int, progress_pct: float) -> str:
    """Render the Rich-markup progress bar for one batch row.

    Color scheme by status: solid green (complete), green/red mix
    (failed), yellow (cancelled), blue with a leading edge (running),
    dim white (pending).
    """
    width = PROGRESS_BAR_WIDTH

    if status == 'complete':
        return f"[bold green]{'━' * width}[/bold green]"

    if status == 'failed':
        # Completed-with-failures: show success/failure proportions.
        return self._create_mixed_bar(success_count, failed_count, total, width)

    filled = int(progress_pct * width)

    if status == 'cancelled':
        done = '━' * filled
        rest = '━' * (width - filled)
        return f"[bold yellow]{done}[/bold yellow][dim yellow]{rest}[/dim yellow]"

    if status == 'running':
        if filled >= width:
            return f"[bold blue]{'━' * width}[/bold blue]"
        # Leading "╸" marks the advancing edge of the bar.
        done = '━' * filled
        rest = '━' * (width - filled - 1)
        return f"[bold blue]{done}[/bold blue][blue]╸[/blue][dim white]{rest}[/dim white]"

    # Pending / unknown status.
    return f"[dim white]{'━' * width}[/dim white]"

def _create_mixed_bar(self, success_count: int, failed_count: int, total: int, bar_width: int) -> str:
"""Create a bar showing green (success) and red (failed) proportions."""
if total == 0:
return f"[dim white]{'━' * bar_width}[/dim white]"

# Use integer division to calculate base widths
success_width = (success_count * bar_width) // total
failed_width = (failed_count * bar_width) // total

# Distribute remainder to maintain exact bar_width
remainder = bar_width - success_width - failed_width
if remainder > 0:
# Distribute remainder based on which segment has larger fractional part
success_fraction = (success_count * bar_width) % total
failed_fraction = (failed_count * bar_width) % total

if success_fraction >= failed_fraction:
success_width += remainder
else:
failed_width += remainder

# Build the bar
bar_parts = []
if success_width > 0:
bar_parts.append(f"[bold green]{'━' * success_width}[/bold green]")
if failed_width > 0:
bar_parts.append(f"[bold red]{'━' * failed_width}[/bold red]")

return "".join(bar_parts)

def _format_status_text(self, status: str, failed_count: int) -> str:
"""Format the status text with appropriate colors and details."""
if status == 'complete':
return "[bold green]Complete[/bold green]"
elif status == 'failed':
if failed_count > 0:
return f"[bold red]Failed ({failed_count})[/bold red]"
return "[bold red]Failed[/bold red]"
elif status == 'cancelled':
return "[bold yellow]Cancelled[/bold yellow]"
elif status == 'running':
spinner = self._spinner_frames[self._spinner_index]
return f"[bold blue]{spinner} Running[/bold blue]"
else:
return "[dim]Pending[/dim]"

def _get_display_stats(self, status: str, success_count: int, failed_count: int, total: int) -> dict:
"""Get the display statistics (completed count and percentage)."""
if status == 'failed' and failed_count > 0:
# For failed batches, show success count to make it clear
completed = success_count
percentage = int((success_count / total) * 100) if total > 0 else 0
else:
# For other statuses, show total processed
completed = success_count + failed_count
percentage = int((completed / total) * 100) if total > 0 else 0

return {'completed': completed, 'percentage': percentage}
Loading
Loading