
Feat/curator/client #491

Merged (12 commits, Feb 13, 2025)
92 changes: 92 additions & 0 deletions src/bespokelabs/curator/client.py
@@ -0,0 +1,92 @@
import json
import logging
import os
import typing as t
import uuid

import httpx
import requests

from bespokelabs.curator.constants import BASE_CLIENT_URL, PUBLIC_CURATOR_VIEWER_URL

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class _SessionStatus:
    """A class to represent the status of a session."""

    STARTED = "STARTED"
    INPROGRESS = "INPROGRESS"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"


class Client:
"""A class to represent the client for the Curator Viewer."""

def __init__(self) -> None:
"""Initialize the client."""
self._session = None
self._state = None
self._hosted = os.environ.get("HOSTED_CURATOR_VIEWER") in ["True", "true", "1", "t"]

@property
def session(self):
"""Get the session ID."""
return self._session

@property
def hosted(self):
"""Check if the client is hosted."""
return self._hosted

def create_session(self, metadata: t.Dict):
"""Sends a POST request to the server to create a session."""
if "HOSTED_CURATOR_VIEWER" not in os.environ:
logger.info("Set HOSTED_CURATOR_VIEWER=1 to view your data live at https://curator.bespokelabs.ai/datasets/.")
if not self.hosted:
return str(uuid.uuid4().hex)

if self.session:
return self.session
metadata.update({"status": _SessionStatus.STARTED})

response = requests.post(f"{BASE_CLIENT_URL}/sessions", json=metadata)

if response.status_code == 200:
self._session = response.json().get("session_id")
logger.info("View your data live at: " + f"{PUBLIC_CURATOR_VIEWER_URL}/{self.session}")
self._state = _SessionStatus.STARTED
return self.session
else:
logger.warning(f"Failed to create session: {response.status_code}, {response.text}")
return str(uuid.uuid4().hex)

    async def _update_state(self):
        async with httpx.AsyncClient() as client:
            response = await client.put(f"{BASE_CLIENT_URL}/sessions/{self.session}", json={"status": self._state})
            if response.status_code != 200:
                logger.debug(f"Failed to update session status: {response.status_code}, {response.text}")

    async def session_completed(self):
        """Updates the session status to completed."""
        self._state = _SessionStatus.COMPLETED
        # Nothing to update when the viewer is disabled or no session was created.
        if not self._hosted or not self.session:
            return
        await self._update_state()

    async def stream_response(self, response_data: str, idx: int):
        """Streams the response data to the server."""
        # Nothing to stream when the viewer is disabled or no session was created.
        if not self._hosted or not self.session:
            return
        if self._state == _SessionStatus.STARTED:
            self._state = _SessionStatus.INPROGRESS
            await self._update_state()

        response_data = json.dumps({"response_data": response_data})
        async with httpx.AsyncClient() as client:
            response = await client.post(f"{BASE_CLIENT_URL}/sessions/{self.session}/responses/{idx}", data=response_data)

        if response.status_code != 200:
            logger.debug(f"Failed to stream response to the Curator Viewer: {response.status_code}, {response.text}")
2 changes: 2 additions & 0 deletions src/bespokelabs/curator/constants.py
@@ -3,3 +3,5 @@
BATCH_REQUEST_ID_TAG = "custom_id"
_CURATOR_DEFAULT_CACHE_DIR = "~/.cache/curator"
_DEFAULT_CACHE_DIR = "~/.cache"
BASE_CLIENT_URL = "https://api.bespokelabs.ai/v0/viewer"
PUBLIC_CURATOR_VIEWER_URL = "https://curator.bespokelabs.ai/datasets"
4 changes: 4 additions & 0 deletions src/bespokelabs/curator/llm/llm.py
@@ -11,6 +11,7 @@
from pydantic import BaseModel
from xxhash import xxh64

from bespokelabs.curator.client import Client
from bespokelabs.curator.constants import _CURATOR_DEFAULT_CACHE_DIR
from bespokelabs.curator.db import MetadataDB
from bespokelabs.curator.llm.prompt_formatter import PromptFormatter
@@ -196,6 +197,7 @@ def __call__(

        metadata_db_path = os.path.join(curator_cache_dir, "metadata.db")
        metadata_db = MetadataDB(metadata_db_path)
        self._request_processor.viewer_client = Client()

        # Get the source code of the prompt function
        prompt_func_source = _get_function_source(self.prompt_formatter.prompt_func)
@@ -214,6 +216,8 @@
"run_hash": fingerprint,
"batch_mode": self.batch_mode,
}
session_id = self._request_processor.viewer_client.create_session(metadata_dict)
metadata_dict["session_id"] = session_id
metadata_db.store_metadata(metadata_dict)

run_cache_dir = os.path.join(curator_cache_dir, fingerprint)
73 changes: 45 additions & 28 deletions src/bespokelabs/curator/request_processor/base_request_processor.py
@@ -52,6 +52,7 @@ def __init__(self, config: RequestProcessorConfig):
        desired_limit = min(10_000_000, hard)
        logger.debug(f"Adjusting file descriptor limit from {soft} to {desired_limit} (hard limit: {hard})")
        resource.setrlimit(resource.RLIMIT_NOFILE, (desired_limit, hard))
        self._viewer_client = None
        self.config = config
        self._cost_processor = cost_processor_factory(self.backend)

@@ -81,6 +82,16 @@ def requests_to_responses(self, generic_request_files: list[str]) -> None:
"""
pass

    @property
    def viewer_client(self):
        """Return the viewer client for the request processor."""
        return self._viewer_client

    @viewer_client.setter
    def viewer_client(self, client):
        """Set the viewer client for the request processor."""
        self._viewer_client = client

    def check_structured_output_support(self) -> bool:
        """Check if the model supports structured output.

@@ -325,6 +336,37 @@ def attempt_loading_cached_dataset(self, parse_func_hash: str) -> Optional["Dataset"]:
"Deleted file and attempting to regenerate dataset from cached LLM responses."
)

    def _process_response(self, data: GenericResponse) -> List | None:
        try:
            data.response_message = self.prompt_formatter.response_to_response_format(data.response_message)
        except (json.JSONDecodeError, ValidationError):
            logger.warning("Skipping response due to error parsing response message into response format")
            return

        # parse_func can return a single row or a list of rows
        responses = None
        if self.prompt_formatter.parse_func:
            try:
                responses = self.prompt_formatter.parse_func(
                    data.generic_request.original_row,
                    data.response_message,
                )
            except Exception as e:
                logger.warning(f"Skipping response due to error in `parse_func` :: {e}")
                return

            if not isinstance(responses, list):
                responses = [responses]
        else:
            # Convert response to dict before adding to dataset
            response_value = data.response_message
            if hasattr(response_value, "model_dump"):
                response_value = response_value.model_dump()
            elif hasattr(response_value, "__dict__"):
                response_value = response_value.__dict__
            responses = [{"response": response_value}]
        return responses

    def create_dataset_files(
        self,
        parse_func_hash: str,
@@ -369,35 +411,10 @@ def create_dataset_files(
                if len(error_sample) < 10:
                    error_sample.append(str(response.response_errors))
                continue

-                try:
-                    response.response_message = self.prompt_formatter.response_to_response_format(response.response_message)
-                except (json.JSONDecodeError, ValidationError):
-                    logger.warning("Skipping response due to error parsing response message into response format")
-                    failed_responses_count += 1
-                    continue
-
-                # parse_func can return a single row or a list of rows
-                if self.prompt_formatter.parse_func:
-                    try:
-                        dataset_rows = self.prompt_formatter.parse_func(
-                            response.generic_request.original_row,
-                            response.response_message,
-                        )
-                    except Exception as e:
-                        logger.error(f"Exception raised in your `parse_func`. {error_help}")
-                        os.remove(dataset_file)
-                        raise e
-                    if not isinstance(dataset_rows, list):
-                        dataset_rows = [dataset_rows]
+                if response.parsed_response_message is None:
+                    dataset_rows = self._process_response(response)
                 else:
-                    # Convert response to dict before adding to dataset
-                    response_value = response.response_message
-                    if hasattr(response_value, "model_dump"):
-                        response_value = response_value.model_dump()
-                    elif hasattr(response_value, "__dict__"):
-                        response_value = response_value.__dict__
-                    dataset_rows = [{"response": response_value}]
+                    dataset_rows = response.parsed_response_message

                for row in dataset_rows:
                    if isinstance(row, BaseModel):
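
For context on the `_process_response` refactor above: `parse_func` may return either a single row or a list of rows, and both shapes end up as a list. A hypothetical illustration (function names and fields invented, not from this PR):

# parse_func returning a single row: _process_response wraps it into [row].
def parse_single(original_row: dict, response_message: str) -> dict:
    return {"question": original_row["question"], "answer": response_message}


# parse_func returning multiple rows: the list is passed through unchanged.
def parse_many(original_row: dict, response_message: str) -> list:
    return [{"question": original_row["question"], "answer": line} for line in response_message.split("\n")]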
@@ -8,6 +8,7 @@
import datetime
import json
import logging
import os
import time
import typing as t
from abc import ABC, abstractmethod
@@ -431,6 +432,7 @@ async def process_requests_from_file(
            done, pending_retries = await asyncio.wait(pending_retries, timeout=0.1)

        status_tracker.stop_tracker()
        await self.viewer_client.session_completed()

        # Log final status
        logger.info(f"Processing complete. Results saved to {response_file}")
@@ -515,7 +517,7 @@ async def handle_single_request_with_retries(
                created_at=request.created_at,
                finished_at=datetime.datetime.now(),
            )
-            await self.append_generic_response(generic_response, response_file)
+            await self.append_generic_response(status_tracker, generic_response, response_file)
            status_tracker.num_tasks_in_progress -= 1
            status_tracker.num_tasks_failed += 1
            return
@@ -526,7 +528,7 @@
            self._semaphore.release()

        # Save response in the base class
-        await self.append_generic_response(generic_response, response_file)
+        await self.append_generic_response(status_tracker, generic_response, response_file)

        status_tracker.num_tasks_in_progress -= 1
        status_tracker.num_tasks_succeeded += 1
@@ -559,14 +561,24 @@ async def call_single_request(
"""
pass

-    async def append_generic_response(self, data: GenericResponse, filename: str) -> None:
+    async def append_generic_response(self, status_tracker: OnlineStatusTracker, data: GenericResponse, filename: str) -> None:
        """Append a response to a jsonl file with async file operations.

        Args:
+            status_tracker: Tracker containing request status
            data: Response data to append
            filename: File to append to
        """
-        json_string = json.dumps(data.model_dump(), default=str)
+        responses = self._process_response(data)
+        if not responses:
+            return
+        data.parsed_response_message = responses
+        data_dump = data.model_dump()
+        json_string = json.dumps(data_dump, default=str)
        async with aiofiles.open(filename, "a") as f:
            await f.write(json_string + "\n")
        logger.debug(f"Successfully appended response to {filename}")
+        filename = os.path.basename(filename).split(".")[0]
+        idx = status_tracker.num_parsed_responses
+        status_tracker.num_parsed_responses = idx + len(responses)
+        await self.viewer_client.stream_response(json_string, idx)
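
One detail worth noting in the new `append_generic_response`: `num_parsed_responses` doubles as a running offset into the viewer's response stream, so multi-row parses keep indices contiguous. Worked through under the code as written: if the first response parses into two rows, it streams at idx 0 and advances the counter to 2, so the next response streams at idx 2 rather than 1.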
@@ -57,6 +57,7 @@ class OnlineStatusTracker:
    num_api_errors: int = 0
    num_other_errors: int = 0
    num_rate_limit_errors: int = 0
    num_parsed_responses: int = 0
    available_request_capacity: float = 1.0
    available_token_capacity: float | _TokenCount = 0
    last_update_time: float = field(default_factory=time.time)
1 change: 1 addition & 0 deletions src/bespokelabs/curator/types/generic_response.py
@@ -29,6 +29,7 @@ class GenericResponse(BaseModel):
"""A generic response model for LLM API requests."""

response_message: Optional[Dict[str, Any]] | str = None
parsed_response_message: Optional[list] = None
response_errors: Optional[List[str]] = None
raw_response: Optional[Dict[str, Any]]
raw_request: Optional[Dict[str, Any]] = None
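
The new `parsed_response_message` field is what ties the pieces together: `append_generic_response` stores the parsed rows alongside the raw response when writing the jsonl cache, and `create_dataset_files` reuses them via the `parsed_response_message is None` branch instead of re-running `parse_func` on cached responses.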