Skip to content

Commit

Permalink
API Release (#469)
Browse files Browse the repository at this point in the history
Co-authored-by: richwardle <richardwardle@macrocosmos.ai>
Co-authored-by: Hollyqui <felix.quinque@gmail.com>
Co-authored-by: richwardle <richard.wardle@macrocosmos.ai>
Co-authored-by: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
  • Loading branch information
5 people authored Dec 2, 2024
1 parent a225417 commit 2f50da2
Show file tree
Hide file tree
Showing 17 changed files with 479 additions and 44 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,5 @@ core
app.config.js
wandb
.vscode
api_keys.json
prompting/api/api_keys.json
1 change: 1 addition & 0 deletions api_keys.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
16 changes: 10 additions & 6 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from loguru import logger

from prompting import mutable_globals
from prompting.api.api import start_api
from prompting.base.dendrite import DendriteResponseEvent
from prompting.base.epistula import query_miners
from prompting.base.forward import log_stream_results
Expand Down Expand Up @@ -92,7 +93,7 @@ async def run_step(self, k: int, timeout: float) -> ValidatorLoggingEvent | Erro
if response_event is None:
logger.warning("No response event collected. This should not be happening.")
return
logger.debug(f"Collected responses in {timer.elapsed_time:.2f} seconds")
logger.debug(f"Collected responses in {timer.final_time:.2f} seconds")

# scoring_manager will score the responses as and when the correct model is loaded
task_scorer.add_to_queue(
Expand All @@ -108,7 +109,7 @@ async def run_step(self, k: int, timeout: float) -> ValidatorLoggingEvent | Erro
return ValidatorLoggingEvent(
block=self.estimate_block,
step=self.step,
step_time=timer.elapsed_time,
step_time=timer.final_time,
response_event=response_event,
task_id=task.task_id,
)
Expand Down Expand Up @@ -163,7 +164,7 @@ async def forward(self):
if not event:
return

event.forward_time = timer.elapsed_time
event.forward_time = timer.final_time

def __enter__(self):
if settings.NO_BACKGROUND_THREAD:
Expand Down Expand Up @@ -196,6 +197,9 @@ def __exit__(self, exc_type, exc_value, traceback):


async def main():
if settings.DEPLOY_API:
asyncio.create_task(start_api())

GPUInfo.log_gpu_info()
# start profiling
asyncio.create_task(profiler.print_stats())
Expand All @@ -214,9 +218,9 @@ async def main():

# start scoring tasks in separate loop
asyncio.create_task(task_scorer.start())
# TODO: Think about whether we want to store the task queue locally in case of a crash
# TODO: Possibly run task scorer & model scheduler with a lock so I don't unload a model whilst it's generating
# TODO: Make weight setting happen as specific intervals as we load/unload models
# # TODO: Think about whether we want to store the task queue locally in case of a crash
# # TODO: Possibly run task scorer & model scheduler with a lock so I don't unload a model whilst it's generating
# # TODO: Make weight setting happen as specific intervals as we load/unload models
with Validator() as v:
while True:
logger.info(
Expand Down
28 changes: 28 additions & 0 deletions prompting/api/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# This ensures uvicorn is imported first
import uvicorn
from fastapi import FastAPI
from loguru import logger

# Now we can safely import the rest
from prompting.api.api_managements.api import router as api_management_router
from prompting.api.gpt_endpoints.api import router as gpt_router
from prompting.api.miner_availabilities.api import router as miner_availabilities_router

# Application instance shared by start_api() (run via "prompting.api.api:app").
app = FastAPI()

# Add routers at the application level
# - /api_management: admin CRUD for API keys (api_managements.api)
# - /miner_availabilities: miner availability queries (miner_availabilities.api)
# - gpt endpoints are mounted at the root (e.g. /v1/chat/completions)
app.include_router(api_management_router, prefix="/api_management", tags=["api_management"])
app.include_router(miner_availabilities_router, prefix="/miner_availabilities", tags=["miner_availabilities"])
app.include_router(gpt_router, tags=["gpt"])


@app.get("/health")
def health():
    """Liveness probe: log the access and report a healthy status."""
    logger.info("Health endpoint accessed.")
    payload = {"status": "healthy"}
    return payload


async def start_api():
    """Serve the FastAPI app inside the caller's already-running event loop.

    This coroutine is scheduled with ``asyncio.create_task`` from the
    validator's ``main()``. ``uvicorn.run()`` cannot be used here: it calls
    ``asyncio.run`` internally and raises when an event loop is already
    running. ``uvicorn.Server(...).serve()`` awaits inside the current loop
    instead.
    """
    logger.info("Starting API...")
    config = uvicorn.Config("prompting.api.api:app", host="0.0.0.0", port=8004, loop="asyncio", reload=False)
    server = uvicorn.Server(config)
    await server.serve()
1 change: 1 addition & 0 deletions prompting/api/api_keys.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
82 changes: 82 additions & 0 deletions prompting/api/api_managements/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import json
import secrets

from fastapi import APIRouter, Depends, Header, HTTPException
from loguru import logger

from prompting.settings import settings

router = APIRouter()


# Load and save functions for API keys
def load_api_keys():
    """Read the API-key store from disk; a missing file means an empty store."""
    try:
        handle = open(settings.API_KEYS_FILE, "r")
    except FileNotFoundError:
        return {}
    with handle:
        return json.load(handle)


def save_api_keys(api_keys):
    """Persist the API-key mapping to the configured JSON file."""
    with open(settings.API_KEYS_FILE, "w") as outfile:
        json.dump(api_keys, outfile)


# Initialize the in-memory key store at import time and write it straight
# back so the file exists from the first run onward.
_keys = load_api_keys()
# Security fix: log only how many keys were loaded — the keys themselves are
# credentials and must never be written to the log.
logger.info(f"Loaded {len(_keys)} API keys")
save_api_keys(_keys)


# Dependency to validate the admin key
def validate_admin_key(admin_key: str = Header(...)):
    """FastAPI dependency: reject requests whose admin-key header is wrong.

    Uses ``secrets.compare_digest`` instead of ``!=`` so the comparison runs
    in constant time and does not leak key length/prefix via timing.
    (Assumes ADMIN_KEY is ASCII — compare_digest requires that for str
    inputs; confirm against settings.)
    """
    if not secrets.compare_digest(admin_key, settings.ADMIN_KEY):
        raise HTTPException(status_code=403, detail="Invalid admin key")


# Dependency to validate API keys
def validate_api_key(api_key: str = Header(...)):
    """FastAPI dependency: return the key's metadata, or reject with 403."""
    try:
        return _keys[api_key]
    except KeyError:
        raise HTTPException(status_code=403, detail="Invalid API key")


@router.post("/create-api-key/")
def create_api_key(rate_limit: int, admin_key: str = Depends(validate_admin_key)):
    """Creates a new API key with a specified rate limit."""
    fresh_key = secrets.token_hex(16)
    _keys[fresh_key] = {"rate_limit": rate_limit, "usage": 0}
    save_api_keys(_keys)
    return {"message": "API key created", "api_key": fresh_key}


@router.put("/modify-api-key/{api_key}")
def modify_api_key(api_key: str, rate_limit: int, admin_key: str = Depends(validate_admin_key)):
    """Modifies the rate limit of an existing API key."""
    try:
        entry = _keys[api_key]
    except KeyError:
        raise HTTPException(status_code=404, detail="API key not found")
    entry["rate_limit"] = rate_limit
    save_api_keys(_keys)
    return {"message": "API key updated", "api_key": api_key}


@router.delete("/delete-api-key/{api_key}")
def delete_api_key(api_key: str, admin_key: str = Depends(validate_admin_key)):
    """Deletes an existing API key."""
    try:
        del _keys[api_key]
    except KeyError:
        raise HTTPException(status_code=404, detail="API key not found")
    save_api_keys(_keys)
    return {"message": "API key deleted"}


@router.get("/demo-endpoint/")
def demo_endpoint(api_key_data: dict = Depends(validate_api_key)):
    """A demo endpoint that requires a valid API key."""
    limit = api_key_data["rate_limit"]
    return {"message": "Access granted", "your_rate_limit": limit}


# # Create FastAPI app and include the router
# app = FastAPI()
# app.include_router(router)
117 changes: 117 additions & 0 deletions prompting/api/gpt_endpoints/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import json
import random
from typing import AsyncGenerator

from fastapi import APIRouter, Depends, HTTPException, Request
from loguru import logger

from prompting.api.api_managements.api import validate_api_key
from prompting.base.dendrite import DendriteResponseEvent, SynapseStreamResult
from prompting.base.epistula import query_miners
from prompting.miner_availability.miner_availability import miner_availabilities
from prompting.rewards.scoring import task_scorer
from prompting.settings import settings
from prompting.tasks.inference import InferenceTask
from prompting.tasks.task_registry import TaskRegistry
from prompting.utils.timer import Timer

router = APIRouter()


async def process_and_collect_stream(miner_id: int, request: dict, response: AsyncGenerator):
    """Relay a miner's streamed completion as SSE while accumulating it for scoring.

    Yields ``data: {json}\n\n`` frames for every chunk of *response*, then a
    terminal ``data: [DONE]\n\n`` frame. After the stream ends, the collected
    text is wrapped in an InferenceTask + DendriteResponseEvent and pushed to
    the scoring queue (block/step of -1 mark it as organic traffic).
    """
    collected_content = []
    collected_chunks_timings = []
    with Timer() as timer:
        async for chunk in response:
            logger.debug(f"Chunk: {chunk}")
            # Only text deltas are accumulated; role/None deltas are skipped
            # but every chunk is still forwarded to the client below.
            if hasattr(chunk, "choices") and chunk.choices and isinstance(chunk.choices[0].delta.content, str):
                collected_content.append(chunk.choices[0].delta.content)
                collected_chunks_timings.append(timer.elapsed_time())
            yield f"data: {json.dumps(chunk.model_dump())}\n\n"

    # Rebuild the task from the original request so the scorer can compare
    # the miner's output against it.
    task = InferenceTask(
        query=request["messages"][-1]["content"],
        messages=[message["content"] for message in request["messages"]],
        model=request.get("model"),
        seed=request.get("seed"),
        response="".join(collected_content),
    )
    logger.debug(f"Adding Organic Request to scoring queue: {task}")
    response_event = DendriteResponseEvent(
        stream_results=[
            SynapseStreamResult(
                uid=miner_id,
                accumulated_chunks=collected_content,
                accumulated_chunks_timings=collected_chunks_timings,
            )
        ],
        uids=[miner_id],
        timeout=settings.NEURON_TIMEOUT,
        completions=["".join(collected_content)],
    )

    # block=-1/step=-1 flag this as an organic (API-originated) request.
    task_scorer.add_to_queue(
        task=task, response=response_event, dataset_entry=task.dataset_entry, block=-1, step=-1, task_id=task.task_id
    )
    yield "data: [DONE]\n\n"


@router.post("/mixture_of_agents")
async def mixture_of_agents(request: Request, api_key_data: dict = Depends(validate_api_key)):
    """Placeholder endpoint — mixture-of-agents inference is not implemented yet."""
    payload = {"message": "Mixture of Agents"}
    return payload


@router.post("/v1/chat/completions")
async def proxy_chat_completions(request: Request, api_key_data: dict = Depends(validate_api_key)):
    """Proxy an OpenAI-style chat completion request to available miners.

    The JSON body follows the chat-completions schema, extended with two
    optional fields that are stripped before forwarding:
      - ``task``: name of a registered task class (400 if unknown);
      - ``stream``: when truthy, the raw streaming response is returned.
    Non-streamed responses are additionally pushed onto the scoring queue
    (block/step of -1 mark organic traffic) and returned as a list of dumps.
    """
    body = await request.json()
    task = TaskRegistry.get_task_by_name(body.get("task"))
    if body.get("task") and not task:
        raise HTTPException(status_code=400, detail=f"Task {body.get('task')} not found")
    logger.debug(f"Requested Task: {body.get('task')}, {task}")

    stream = body.get("stream")
    # Record whether the caller supplied a seed BEFORE defaulting it below;
    # computing bool(body.get("seed")) afterwards always logged True.
    seed_supplied = body.get("seed") is not None
    body = {k: v for k, v in body.items() if k not in ["task", "stream"]}
    # `task` is the task *class* (it is instantiated further down), so its
    # name is task.__name__; task.__class__.__name__ would be the metaclass
    # name, and "NoneType" when no task was requested.
    body["task"] = task.__name__ if task else None
    body["seed"] = body.get("seed") or str(random.randint(0, 1_000_000))
    logger.debug(f"Seed provided by miner: {seed_supplied} -- Using seed: {body.get('seed')}")

    if settings.TEST_MINER_IDS:
        available_miners = settings.TEST_MINER_IDS
    else:
        # Always bind available_miners: the previous short-circuit left it
        # undefined in mock mode, causing a NameError at query_miners below.
        available_miners = miner_availabilities.get_available_miners(task=task, model=body.get("model"))
        if not settings.mode == "mock" and not available_miners:
            raise HTTPException(
                status_code=503,
                detail=f"No miners available for model: {body.get('model')} and task: {task.__name__}",
            )

    response = query_miners(available_miners, json.dumps(body).encode("utf-8"), stream=stream)
    if stream:
        return response
    else:
        response = await response
        response_event = DendriteResponseEvent(
            stream_results=response,
            uids=available_miners,
            timeout=settings.NEURON_TIMEOUT,
            completions=["".join(res.accumulated_chunks) for res in response],
        )

        # NOTE(review): if no "task" field was sent, `task` is None here and
        # this instantiation fails — presumably a default (inference) task is
        # intended; confirm with callers.
        task = task(
            query=body["messages"][-1]["content"],
            messages=[message["content"] for message in body["messages"]],
            model=body.get("model"),
            seed=body.get("seed"),
            response=response_event,
        )

        task_scorer.add_to_queue(
            task=task,
            response=response_event,
            dataset_entry=task.dataset_entry,
            block=-1,
            step=-1,
            task_id=task.task_id,
        )

        return [res.model_dump() for res in response]
Empty file.
29 changes: 29 additions & 0 deletions prompting/api/miner_availabilities/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Literal

from fastapi import APIRouter
from loguru import logger

from prompting.miner_availability.miner_availability import miner_availabilities
from prompting.tasks.task_registry import TaskRegistry

router = APIRouter()


@router.post("/miner_availabilities")
async def get_miner_availabilities(uids: list[int] | None = None):
    """Return availability records for the given uids, or for every tracked miner."""
    if uids:
        selected = {}
        for uid in uids:
            selected[uid] = miner_availabilities.miners.get(uid)
        return selected
    logger.info(f"Returning all miner availabilities for {len(miner_availabilities.miners)} miners")
    return miner_availabilities.miners


@router.get("/get_available_miners")
async def get_available_miners(
    task: Literal[tuple([config.task.__name__ for config in TaskRegistry.task_configs])] | None = None,
    model: str | None = None,
    k: int = 10,
):
    """Resolve *task* to its registered config (if any) and delegate to the tracker."""
    logger.info(f"Getting {k} available miners for task {task} and model {model}")
    matching_config = next(
        (config for config in TaskRegistry.task_configs if config.task.__name__ == task),
        None,
    )
    return miner_availabilities.get_available_miners(task=matching_config, model=model, k=k)
Loading

0 comments on commit 2f50da2

Please sign in to comment.