generated from opentensor/bittensor-subnet-template
-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: richwardle <richardwardle@macrocosmos.ai> Co-authored-by: Hollyqui <felix.quinque@gmail.com> Co-authored-by: richwardle <richard.wardle@macrocosmos.ai> Co-authored-by: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
- Loading branch information
1 parent
a225417
commit 2f50da2
Showing
17 changed files
with
479 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -177,3 +177,5 @@ core | |
app.config.js | ||
wandb | ||
.vscode | ||
api_keys.json | ||
prompting/api/api_keys.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# This ensures uvicorn is imported first | ||
import uvicorn | ||
from fastapi import FastAPI | ||
from loguru import logger | ||
|
||
# Now we can safely import the rest | ||
from prompting.api.api_managements.api import router as api_management_router | ||
from prompting.api.gpt_endpoints.api import router as gpt_router | ||
from prompting.api.miner_availabilities.api import router as miner_availabilities_router | ||
|
||
app = FastAPI() | ||
|
||
# Add routers at the application level | ||
app.include_router(api_management_router, prefix="/api_management", tags=["api_management"]) | ||
app.include_router(miner_availabilities_router, prefix="/miner_availabilities", tags=["miner_availabilities"]) | ||
app.include_router(gpt_router, tags=["gpt"]) | ||
|
||
|
||
@app.get("/health")
def health():
    """Liveness probe: log the access and report a fixed healthy status."""
    status_payload = {"status": "healthy"}
    logger.info("Health endpoint accessed.")
    return status_payload
|
||
|
||
# if __name__ == "__main__": | ||
async def start_api():
    """Serve the FastAPI application on 0.0.0.0:8004 until the server stops.

    ``uvicorn.run`` is a synchronous entry point that starts an event loop of
    its own, so it cannot be used from inside a coroutine that is already
    running on a loop. Building the server explicitly and awaiting
    ``Server.serve()`` runs it on the caller's loop instead.
    """
    logger.info("Starting API...")
    config = uvicorn.Config(
        "prompting.api.api:app",
        host="0.0.0.0",
        port=8004,
        loop="asyncio",
        reload=False,
    )
    server = uvicorn.Server(config)
    await server.serve()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import json | ||
import secrets | ||
|
||
from fastapi import APIRouter, Depends, Header, HTTPException | ||
from loguru import logger | ||
|
||
from prompting.settings import settings | ||
|
||
router = APIRouter() | ||
|
||
|
||
# Load and save functions for API keys | ||
def load_api_keys():
    """Read the API-key store from ``settings.API_KEYS_FILE``.

    Returns:
        The parsed JSON content of the file, or an empty dict when the file
        does not exist.
    """
    try:
        handle = open(settings.API_KEYS_FILE, "r")
    except FileNotFoundError:
        # A missing file simply means no keys have been created yet.
        return {}
    with handle:
        return json.load(handle)
|
||
|
||
def save_api_keys(api_keys):
    """Overwrite ``settings.API_KEYS_FILE`` with ``api_keys`` serialized as JSON."""
    with open(settings.API_KEYS_FILE, "w") as handle:
        serialized = json.dumps(api_keys)
        handle.write(serialized)
|
||
|
||
# Load the key store once at import time and write it straight back, so the
# file exists on disk even when it was missing before.
_keys = load_api_keys()
# Log only how many keys were loaded: the keys themselves are credentials and
# must not end up in log output.
logger.info(f"Loaded {len(_keys)} API keys")
save_api_keys(_keys)
|
||
|
||
# Dependency to validate the admin key | ||
def validate_admin_key(admin_key: str = Header(...)):
    """FastAPI dependency that rejects requests lacking the correct admin key.

    Args:
        admin_key: Value of the required ``admin_key`` request header.

    Raises:
        HTTPException: 403 when the header does not match
            ``settings.ADMIN_KEY`` (including when no string admin key is
            configured).
    """
    expected = settings.ADMIN_KEY
    # A non-string setting (e.g. None) can never equal a header value, which
    # is also what the plain inequality check concluded; reject up front so
    # compare_digest only ever sees bytes.
    matches = isinstance(expected, str) and secrets.compare_digest(
        admin_key.encode("utf-8"), expected.encode("utf-8")
    )
    # compare_digest is a constant-time comparison, which reduces the risk of
    # recovering the admin key through response-timing differences.
    if not matches:
        raise HTTPException(status_code=403, detail="Invalid admin key")
|
||
|
||
# Dependency to validate API keys | ||
def validate_api_key(api_key: str = Header(...)):
    """FastAPI dependency that resolves an ``api_key`` header to its stored record.

    Returns:
        The entry stored in ``_keys`` for the supplied key.

    Raises:
        HTTPException: 403 when the key is not present in ``_keys``.
    """
    try:
        return _keys[api_key]
    except KeyError:
        raise HTTPException(status_code=403, detail="Invalid API key") from None
|
||
|
||
@router.post("/create-api-key/")
def create_api_key(rate_limit: int, admin_key: str = Depends(validate_admin_key)):
    """Create a new API key with the given rate limit and persist the key store.

    The key is a random 32-character hex token; its usage counter starts at 0.
    Access is gated by the ``validate_admin_key`` dependency.
    """
    token = secrets.token_hex(16)
    record = {"rate_limit": rate_limit, "usage": 0}
    _keys[token] = record
    save_api_keys(_keys)
    return {"message": "API key created", "api_key": token}
|
||
|
||
@router.put("/modify-api-key/{api_key}")
def modify_api_key(api_key: str, rate_limit: int, admin_key: str = Depends(validate_admin_key)):
    """Update the rate limit of an existing API key and persist the key store.

    Raises:
        HTTPException: 404 when ``api_key`` is not in the key store.
    """
    try:
        record = _keys[api_key]
    except KeyError:
        raise HTTPException(status_code=404, detail="API key not found") from None
    record["rate_limit"] = rate_limit
    save_api_keys(_keys)
    return {"message": "API key updated", "api_key": api_key}
|
||
|
||
@router.delete("/delete-api-key/{api_key}")
def delete_api_key(api_key: str, admin_key: str = Depends(validate_admin_key)):
    """Remove an existing API key and persist the key store.

    Raises:
        HTTPException: 404 when ``api_key`` is not in the key store.
    """
    try:
        del _keys[api_key]
    except KeyError:
        raise HTTPException(status_code=404, detail="API key not found") from None
    save_api_keys(_keys)
    return {"message": "API key deleted"}
|
||
|
||
@router.get("/demo-endpoint/")
def demo_endpoint(api_key_data: dict = Depends(validate_api_key)):
    """Demo route: requires a valid API key and echoes that key's rate limit."""
    rate_limit = api_key_data["rate_limit"]
    return {"message": "Access granted", "your_rate_limit": rate_limit}
|
||
|
||
# # Create FastAPI app and include the router | ||
# app = FastAPI() | ||
# app.include_router(router) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import json | ||
import random | ||
from typing import AsyncGenerator | ||
|
||
from fastapi import APIRouter, Depends, HTTPException, Request | ||
from loguru import logger | ||
|
||
from prompting.api.api_managements.api import validate_api_key | ||
from prompting.base.dendrite import DendriteResponseEvent, SynapseStreamResult | ||
from prompting.base.epistula import query_miners | ||
from prompting.miner_availability.miner_availability import miner_availabilities | ||
from prompting.rewards.scoring import task_scorer | ||
from prompting.settings import settings | ||
from prompting.tasks.inference import InferenceTask | ||
from prompting.tasks.task_registry import TaskRegistry | ||
from prompting.utils.timer import Timer | ||
|
||
router = APIRouter() | ||
|
||
|
||
async def process_and_collect_stream(miner_id: int, request: dict, response: AsyncGenerator):
    """Relay a miner's streamed chunks as ``data:`` lines and queue the result for scoring.

    Each chunk that carries string delta content is recorded (text plus the
    timer reading at arrival) and re-emitted as a ``data: <json>`` line. Once
    the stream is exhausted, the collected text is wrapped in an
    ``InferenceTask`` and a ``DendriteResponseEvent``, handed to
    ``task_scorer``, and a final ``data: [DONE]`` line is emitted.

    Args:
        miner_id: UID of the miner that produced ``response``.
        request: Request payload; ``messages`` is read, ``model`` and ``seed``
            are optional.
        response: Async iterable of chunks from the miner.

    Yields:
        Server-sent-event formatted strings.
    """
    text_parts = []
    arrival_times = []
    with Timer() as timer:
        async for chunk in response:
            logger.debug(f"Chunk: {chunk}")
            # Only chunks that actually carry text are recorded and forwarded.
            if not (hasattr(chunk, "choices") and chunk.choices):
                continue
            delta_text = chunk.choices[0].delta.content
            if not isinstance(delta_text, str):
                continue
            text_parts.append(delta_text)
            arrival_times.append(timer.elapsed_time())
            yield f"data: {json.dumps(chunk.model_dump())}\n\n"

    task = InferenceTask(
        query=request["messages"][-1]["content"],
        messages=[message["content"] for message in request["messages"]],
        model=request.get("model"),
        seed=request.get("seed"),
        response="".join(text_parts),
    )
    logger.debug(f"Adding Organic Request to scoring queue: {task}")

    stream_result = SynapseStreamResult(
        uid=miner_id,
        accumulated_chunks=text_parts,
        accumulated_chunks_timings=arrival_times,
    )
    response_event = DendriteResponseEvent(
        stream_results=[stream_result],
        uids=[miner_id],
        timeout=settings.NEURON_TIMEOUT,
        completions=["".join(text_parts)],
    )

    # block/step are passed as -1 here; presumably this marks a request that
    # did not originate from a validator step - confirm against task_scorer.
    task_scorer.add_to_queue(
        task=task,
        response=response_event,
        dataset_entry=task.dataset_entry,
        block=-1,
        step=-1,
        task_id=task.task_id,
    )
    yield "data: [DONE]\n\n"
|
||
|
||
@router.post("/mixture_of_agents")
async def mixture_of_agents(request: Request, api_key_data: dict = Depends(validate_api_key)):
    """Placeholder route: requires a valid API key and returns a fixed message."""
    placeholder = {"message": "Mixture of Agents"}
    return placeholder
|
||
|
||
@router.post("/v1/chat/completions")
async def proxy_chat_completions(request: Request, api_key_data: dict = Depends(validate_api_key)):
    """Forward a chat-completion request to miners and queue the result for scoring.

    The JSON body may carry ``task`` (looked up through ``TaskRegistry``),
    ``stream``, ``model``, ``seed`` and ``messages``. A random seed is filled
    in when none is supplied. Miners come from ``settings.TEST_MINER_IDS`` when
    set, otherwise from ``miner_availabilities``.

    Returns:
        The object returned by ``query_miners`` when ``stream`` is truthy,
        otherwise a list with one dumped result per miner.

    Raises:
        HTTPException: 400 when a task name is given but not found; 503 when
            no miners are available outside mock mode.
    """
    body = await request.json()
    requested_task = body.get("task")
    task = TaskRegistry.get_task_by_name(requested_task)
    if requested_task and not task:
        raise HTTPException(status_code=400, detail=f"Task {requested_task} not found")
    logger.debug(f"Requested Task: {requested_task}, {task}")

    stream = body.get("stream")
    # Record this before the default seed is filled in; evaluated afterwards
    # it would always be True and the log line would be meaningless.
    seed_provided = bool(body.get("seed"))
    body = {k: v for k, v in body.items() if k not in ["task", "stream"]}
    # NOTE(review): `task` is called like a class further down, so
    # `task.__class__.__name__` yields the name of its metaclass (or
    # "NoneType"), not the task name. Left unchanged because the value is
    # sent to miners - confirm what they expect.
    body["task"] = task.__class__.__name__
    body["seed"] = body.get("seed") or str(random.randint(0, 1_000_000))
    logger.debug(f"Seed provided by miner: {seed_provided} -- Using seed: {body.get('seed')}")

    if settings.TEST_MINER_IDS:
        available_miners = settings.TEST_MINER_IDS
    else:
        # Always bind `available_miners`: previously it stayed undefined in
        # mock mode and the query below failed with UnboundLocalError.
        available_miners = miner_availabilities.get_available_miners(task=task, model=body.get("model"))
        if not available_miners and not settings.mode == "mock":
            # `task` may be None when the request named no task, so the name
            # is read defensively to keep this a 503 rather than a crash.
            task_name = getattr(task, "__name__", None)
            raise HTTPException(
                status_code=503,
                detail=f"No miners available for model: {body.get('model')} and task: {task_name}",
            )

    response = query_miners(available_miners, json.dumps(body).encode("utf-8"), stream=stream)
    if stream:
        # NOTE(review): returned as-is without awaiting or wrapping; confirm
        # that query_miners yields something FastAPI can send in this mode.
        return response

    response = await response
    response_event = DendriteResponseEvent(
        stream_results=response,
        uids=available_miners,
        timeout=settings.NEURON_TIMEOUT,
        completions=["".join(res.accumulated_chunks) for res in response],
    )

    # NOTE(review): this call fails when the request named no task and the
    # registry lookup returned None - confirm the intended default.
    task = task(
        query=body["messages"][-1]["content"],
        messages=[message["content"] for message in body["messages"]],
        model=body.get("model"),
        seed=body.get("seed"),
        response=response_event,
    )

    task_scorer.add_to_queue(
        task=task,
        response=response_event,
        dataset_entry=task.dataset_entry,
        block=-1,
        step=-1,
        task_id=task.task_id,
    )

    return [res.model_dump() for res in response]
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from typing import Literal | ||
|
||
from fastapi import APIRouter | ||
from loguru import logger | ||
|
||
from prompting.miner_availability.miner_availability import miner_availabilities | ||
from prompting.tasks.task_registry import TaskRegistry | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.post("/miner_availabilities")
async def get_miner_availabilities(uids: list[int] | None = None):
    """Return miner availability records, optionally restricted to ``uids``.

    With a non-empty ``uids`` list the result maps each requested uid to its
    record (``None`` for unknown uids); otherwise the full ``miners`` mapping
    is returned.
    """
    if not uids:
        logger.info(f"Returning all miner availabilities for {len(miner_availabilities.miners)} miners")
        return miner_availabilities.miners

    selected = {}
    for uid in uids:
        selected[uid] = miner_availabilities.miners.get(uid)
    return selected
|
||
|
||
@router.get("/get_available_miners")
async def get_available_miners(
    task: Literal[tuple([config.task.__name__ for config in TaskRegistry.task_configs])] | None = None,
    model: str | None = None,
    k: int = 10,
):
    """Return available miners, optionally filtered by task name and model.

    ``task`` must be the class name of a registered task; it is resolved to
    the first matching entry in ``TaskRegistry.task_configs`` (``None`` when
    nothing matches) before being passed on to ``miner_availabilities``.
    """
    logger.info(f"Getting {k} available miners for task {task} and model {model}")
    task_config = next(
        (candidate for candidate in TaskRegistry.task_configs if candidate.task.__name__ == task),
        None,
    )
    return miner_availabilities.get_available_miners(task=task_config, model=model, k=k)
Oops, something went wrong.