SN1-348: Improve api stability #488

Merged 6 commits on Dec 9, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ __pycache__/
.DS_Store
**/.DS_Store

*.npy
*.npz
prompting/storage/

3 changes: 2 additions & 1 deletion prompting/api/api.py
@@ -7,6 +7,7 @@
from prompting.api.api_managements.api import router as api_management_router
from prompting.api.gpt_endpoints.api import router as gpt_router
from prompting.api.miner_availabilities.api import router as miner_availabilities_router
from prompting.settings import settings

app = FastAPI()

@@ -25,4 +26,4 @@ def health():
# if __name__ == "__main__":
async def start_api():
logger.info("Starting API...")
uvicorn.run("prompting.api.api:app", host="0.0.0.0", port=8004, loop="asyncio", reload=False)
uvicorn.run("prompting.api.api:app", host="0.0.0.0", port=settings.API_PORT, loop="asyncio", reload=False)
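
The port is no longer hardcoded: uvicorn now binds to settings.API_PORT, which the prompting/settings.py change below adds with a default of 8094. A minimal, self-contained sketch of the pattern, assuming pydantic v1-style BaseSettings (as the env= keyword in the settings diff suggests); the Settings class here is a stand-in for prompting.settings:

```python
import uvicorn
from fastapi import FastAPI
from pydantic import BaseSettings, Field


class Settings(BaseSettings):
    # Resolved from the API_PORT environment variable at instantiation,
    # falling back to the same default this PR introduces.
    API_PORT: int = Field(8094, env="API_PORT")


settings = Settings()
app = FastAPI()


@app.get("/health")
def health():
    return "ok"


if __name__ == "__main__":
    # e.g. `API_PORT=8004 python app.py` keeps the old port, no code change needed.
    uvicorn.run(app, host="0.0.0.0", port=settings.API_PORT)
```
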
150 changes: 117 additions & 33 deletions prompting/api/gpt_endpoints/api.py
@@ -4,10 +4,10 @@

from fastapi import APIRouter, Depends, HTTPException, Request
from loguru import logger
from starlette.responses import StreamingResponse

from prompting.api.api_managements.api import validate_api_key
from prompting.base.dendrite import DendriteResponseEvent, SynapseStreamResult
from prompting.base.epistula import query_miners
from prompting.miner_availability.miner_availability import miner_availabilities
from prompting.rewards.scoring import task_scorer
from prompting.settings import settings
@@ -61,6 +61,80 @@ async def mixture_of_agents(request: Request, api_key_data: dict = Depends(validate_api_key)):
return {"message": "Mixture of Agents"}


import asyncio

from prompting.base.epistula import make_openai_query


async def query_endpoint(metagraph, wallet, body, uid, stream):
try:
response = await make_openai_query(metagraph=metagraph, wallet=wallet, body=body, uid=uid, stream=stream)
if stream:

async def stream_response():
collected_content = []
collected_chunks_timings = []
with Timer() as timer:
async for chunk in response:
if (
hasattr(chunk, "choices")
and chunk.choices
and isinstance(chunk.choices[0].delta.content, str)
):
content = chunk.choices[0].delta.content
collected_content.append(content)
collected_chunks_timings.append(timer.elapsed_time())
yield f"data: {json.dumps(chunk.model_dump())}\n\n"

# Add task to scoring queue after stream completes
task_obj = InferenceTask(
query=body["messages"][-1]["content"],
messages=[message["content"] for message in body["messages"]],
model=body.get("model"),
seed=body.get("seed"),
response="".join(collected_content),
)
response_event = DendriteResponseEvent(
stream_results=[
SynapseStreamResult(
uid=uid,
accumulated_chunks=collected_content,
accumulated_chunks_timings=collected_chunks_timings,
)
],
uids=[uid],
timeout=settings.NEURON_TIMEOUT,
completions=["".join(collected_content)],
)
task_scorer.add_to_queue(
task=task_obj,
response=response_event,
dataset_entry=task_obj.dataset_entry,
block=-1,
step=-1,
task_id=task_obj.task_id,
)
yield "data: [DONE]\n\n"

return StreamingResponse(stream_response(), media_type="text/event-stream")
else:
return response

except Exception as e:
logger.error(f"Error querying miner with uid {uid}: {e}")
return None


async def query_all_endpoints(metagraph, wallet, body, uids, stream):
tasks = [query_endpoint(metagraph, wallet, body, uid, stream) for uid in uids]
for task in asyncio.as_completed(tasks):
result = await task
if result is not None:
return result
raise HTTPException(status_code=503, detail="No valid response from any endpoint")


# TODO: Modify this so ALL of the responses are added for scoring rather than just the first one
@router.post("/v1/chat/completions")
async def proxy_chat_completions(request: Request, api_key_data: dict = Depends(validate_api_key)):
body = await request.json()
@@ -69,13 +143,16 @@ async def proxy_chat_completions(request: Request, api_key_data: dict = Depends(validate_api_key)):
raise HTTPException(status_code=400, detail=f"Task {body.get('task')} not found")
logger.debug(f"Requested Task: {body.get('task')}, {task}")

stream = body.get("stream")
stream = body.get("stream", False)
body = {k: v for k, v in body.items() if k not in ["task", "stream"]}
body["task"] = task.__class__.__name__
body["seed"] = body.get("seed") or str(random.randint(0, 1_000_000))
logger.debug(f"Seed provided by miner: {bool(body.get('seed'))} -- Using seed: {body.get('seed')}")

if settings.TEST_MINER_IDS:
# Get available miners
if uids := body.get("uids"):
available_miners = uids
elif settings.TEST_MINER_IDS:
available_miners = settings.TEST_MINER_IDS
elif not settings.mode == "mock" and not (
available_miners := miner_availabilities.get_available_miners(task=task, model=body.get("model"))
@@ -84,33 +161,40 @@ async def proxy_chat_completions(request: Request, api_key_data: dict = Depends(validate_api_key)):
status_code=503,
detail=f"No miners available for model: {body.get('model')} and task: {task.__class__.__name__}",
)

response = query_miners(available_miners, json.dumps(body).encode("utf-8"), stream=stream)
if stream:
return response
else:
response = await response
response_event = DendriteResponseEvent(
stream_results=response,
uids=available_miners,
timeout=settings.NEURON_TIMEOUT,
completions=["".join(res.accumulated_chunks) for res in response],
)

task = task(
query=body["messages"][-1]["content"],
messages=[message["content"] for message in body["messages"]],
model=body.get("model"),
seed=body.get("seed"),
response=response_event,
)

task_scorer.add_to_queue(
task=task,
response=response_event,
dataset_entry=task.dataset_entry,
block=-1,
step=-1,
task_id=task.task_id,
)
return [res.model_dump() for res in response]
random.shuffle(available_miners)

try:
result = await query_all_endpoints(settings.METAGRAPH, settings.WALLET, body, available_miners, stream)
if not stream:
# Handle non-streaming response scoring
response_event = DendriteResponseEvent(
stream_results=[
SynapseStreamResult(
uid=available_miners[0],
accumulated_chunks=[result.choices[0].message.content],
accumulated_chunks_timings=[0.0],
)
],
uids=[available_miners[0]],
timeout=settings.NEURON_TIMEOUT,
completions=[result.choices[0].message.content],
)
task_obj = InferenceTask(
query=body["messages"][-1]["content"],
messages=[message["content"] for message in body["messages"]],
model=body.get("model"),
seed=body.get("seed"),
response=result.choices[0].message.content,
)
task_scorer.add_to_queue(
task=task_obj,
response=response_event,
dataset_entry=task_obj.dataset_entry,
block=-1,
step=-1,
task_id=task_obj.task_id,
)
return result
except Exception as e:
logger.error(f"Failed to get a valid response: {e}")
raise HTTPException(status_code=503, detail=str(e))
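
Because the proxy mirrors the OpenAI chat-completions schema (scripts/test_api.py below drives it with a stock client), a request can be sketched as follows. The api_key value, the exact task string accepted by body.get("task"), and the model name are all assumptions here:

```python
import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(
        base_url="http://localhost:8094/v1",  # default API_PORT from this PR
        api_key="my-api-key",                 # checked server-side by validate_api_key
        max_retries=0,
    )
    # extra_body carries the subnet-specific fields the endpoint reads:
    # a task selector, an optional fixed seed, and (new in this PR) optional
    # explicit miner uids that bypass the availability lookup.
    stream = await client.chat.completions.create(
        model="some-miner-model",  # illustrative; miners are matched by model
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
        extra_body={"task": "InferenceTask", "seed": "42", "uids": [1, 2]},
    )
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)


asyncio.run(main())
```

Dropping stream=True exercises the non-streaming path, where the endpoint queues the single response for scoring before returning it.
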
91 changes: 61 additions & 30 deletions prompting/base/epistula.py
@@ -3,7 +3,7 @@
import time
from hashlib import sha256
from math import ceil
from typing import Annotated, Any, Dict, List, Optional
from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional
from uuid import uuid4

import bittensor as bt
@@ -79,7 +79,34 @@ async def add_headers(request: httpx.Request):
return add_headers


async def query_miners(uids: list = [], body: bytes = b"", stream: bool = False):
async def merged_stream(responses: list[AsyncGenerator]):
streams = [response.__aiter__() for response in responses if not isinstance(response, Exception)]
pending = {}
for stream in streams:
try:
task = asyncio.create_task(stream.__anext__())
pending[task] = stream
except StopAsyncIteration:
continue # Skip empty streams

while pending:
done, _ = await asyncio.wait(pending.keys(), return_when=asyncio.FIRST_COMPLETED)
for task in done:
stream = pending.pop(task)
try:
result = task.result()
yield result
# Schedule the next item from the same stream
next_task = asyncio.create_task(stream.__anext__())
pending[next_task] = stream
except StopAsyncIteration:
# Stream is exhausted
pass
except Exception as e:
logger.error(f"Error while streaming: {e}")


async def query_miners(uids: list = [], body: bytes = b"", stream: bool = False, return_first: bool = False):
try:
tasks = []
for uid in uids:
@@ -94,44 +121,20 @@ async def query_miners(uids: list = [], body: bytes = b"", stream: bool = False)
)
)
)
if return_first:
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
return [await done.pop()]
responses = await asyncio.gather(*tasks, return_exceptions=True)

# Filter out exceptions from responses
exceptions = [resp for resp in responses if isinstance(resp, Exception)]
if exceptions:
for exc in exceptions:
logger.error(f"Error in handle_inference: {exc}")
# Handle exceptions as needed

if stream:
# 'responses' is a list of async iterators (chat objects)
async def merged_stream():
streams = [response.__aiter__() for response in responses if not isinstance(response, Exception)]
pending = {}
for stream in streams:
try:
task = asyncio.create_task(stream.__anext__())
pending[task] = stream
except StopAsyncIteration:
continue # Skip empty streams

while pending:
done, _ = await asyncio.wait(pending.keys(), return_when=asyncio.FIRST_COMPLETED)
for task in done:
stream = pending.pop(task)
try:
result = task.result()
yield result
# Schedule the next item from the same stream
next_task = asyncio.create_task(stream.__anext__())
pending[next_task] = stream
except StopAsyncIteration:
# Stream is exhausted
pass
except Exception as e:
logger.error(f"Error while streaming: {e}")

return merged_stream()
return merged_stream(responses)
else:
# 'responses' is a list of SynapseStreamResult objects
return [resp for resp in responses if not isinstance(resp, Exception)]
Expand Down Expand Up @@ -185,6 +188,34 @@ async def handle_availability(
return {}


async def make_openai_query(
metagraph: "bt.NonTorchMetagraph",
wallet: "bt.wallet",
body: dict[str, Any],
uid: int,
stream: bool = False,
):
axon_info = metagraph.axons[uid]
miner = openai.AsyncOpenAI(
base_url=f"http://{axon_info.ip}:{axon_info.port}/v1",
api_key="Apex",
max_retries=0,
timeout=Timeout(10, connect=5, read=10),
http_client=openai.DefaultAsyncHttpxClient(
event_hooks={"request": [create_header_hook(wallet.hotkey, axon_info.hotkey)]}
),
)
# payload = json.loads(body)
payload = body
chat = await miner.chat.completions.create(
messages=payload["messages"],
model=payload["model"],
stream=stream,
extra_body={k: v for k, v in payload.items() if k not in ["messages", "model"]},
)
return chat


async def handle_inference(
metagraph: "bt.NonTorchMetagraph",
wallet: "bt.wallet",
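
Factoring merged_stream out of query_miners also makes the interleaving behavior easy to check in isolation. A stripped-down demo of the same pattern (the function and generator names here are illustrative): items are yielded in completion order across all source streams, so one slow miner cannot hold back chunks from a fast one.

```python
import asyncio


async def ticker(name: str, delay: float, n: int):
    for i in range(n):
        await asyncio.sleep(delay)
        yield f"{name}-{i}"


async def merged(streams):
    # Map each in-flight __anext__ task back to its source stream.
    pending = {asyncio.create_task(s.__anext__()): s for s in streams}
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            stream = pending.pop(task)
            try:
                yield task.result()
                # Immediately schedule the next item from the same stream.
                pending[asyncio.create_task(stream.__anext__())] = stream
            except StopAsyncIteration:
                pass  # this stream is exhausted; the others keep going


async def main():
    async for item in merged([ticker("fast", 0.1, 3), ticker("slow", 0.25, 2)]):
        print(item)  # fast-0, fast-1, slow-0, fast-2, slow-1


asyncio.run(main())
```
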
1 change: 1 addition & 0 deletions prompting/settings.py
@@ -66,6 +66,7 @@ class Settings(BaseSettings):
SCORING_QUEUE_LENGTH_THRESHOLD: int = Field(10, env="SCORING_QUEUE_LENGTH_THRESHOLD")
HF_TOKEN: Optional[str] = Field(None, env="HF_TOKEN")
DEPLOY_API: bool = Field(False, env="DEPLOY_API")
API_PORT: int = Field(8094, env="API_PORT")

# API Management.
API_KEYS_FILE: str = Field("api_keys.json", env="API_KEYS_FILE")
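
Since BaseSettings resolves environment variables when the object is constructed, the port can now vary per deployment without code edits. A sketch, assuming prompting.settings imports cleanly on its own (in practice it may require other environment variables to be set):

```python
import os

# Must be set before Settings is instantiated, otherwise the 8094 default wins.
os.environ["API_PORT"] = "8004"

from prompting.settings import Settings  # noqa: E402

assert Settings().API_PORT == 8004
```
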
3 changes: 3 additions & 0 deletions prompting/weight_setting/weight_setter.py
@@ -20,6 +20,9 @@
try:
PAST_WEIGHTS = [np.load(FILENAME)]
logger.info(f"Loaded weights from file: {PAST_WEIGHTS}")
except FileNotFoundError:
logger.info("No weights file found - this is expected on a new validator, starting with empty weights")
PAST_WEIGHTS = []
except Exception as ex:
logger.exception(f"Couldn't load weights from file: {ex}")
PAST_WEIGHTS = []
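
The ordering of the two except clauses is load-bearing: FileNotFoundError is a subclass of Exception (via OSError), so the specific handler must come first or the generic one would log a misleading traceback on every fresh validator. A minimal illustration of the same pattern, with an illustrative path standing in for the module's FILENAME constant:

```python
import numpy as np

try:
    past_weights = [np.load("past_weights.npy")]  # illustrative path
except FileNotFoundError:
    # Expected on a first run: start clean, no traceback needed.
    past_weights = []
except Exception as ex:
    # Anything else (corrupt file, bad permissions) is a real error.
    print(f"Couldn't load weights: {ex}")
    past_weights = []
```
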
2 changes: 1 addition & 1 deletion scripts/test_api.py
@@ -38,7 +38,7 @@ async def combined_header_hook(request):
return openai.AsyncOpenAI(
base_url=f"http://localhost:{port}/v1",
max_retries=0,
timeout=Timeout(30, connect=10, read=20),
timeout=Timeout(60, connect=20, read=40),
http_client=openai.DefaultAsyncHttpxClient(event_hooks={"request": [combined_header_hook]}),
)

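
The doubled client timeouts follow httpx.Timeout semantics (the Timeout in this script is presumably httpx.Timeout, which the openai client accepts): the positional argument is the default for every timeout class, and keywords override per class. So the new value means connect 20s, read 40s, and write/pool 60s. A sketch to confirm, assuming httpx is installed:

```python
import httpx

# The positional value is the default for all four timeout classes;
# keyword arguments override individual ones.
t = httpx.Timeout(60, connect=20, read=40)
print(t.connect, t.read, t.write, t.pool)  # 20 40 60 60
```
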