Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API Release #469

Merged
merged 35 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
4457d4e
Initial upload
Nov 19, 2024
147aa27
Merge pull request #450 from macrocosm-os/staging
bkb2135 Nov 19, 2024
62ae30c
Get everything working
Nov 19, 2024
28adabb
Merge pull request #460 from macrocosm-os/staging
bkb2135 Nov 20, 2024
8e25f33
SN1-331: Adding initial draft for endpoints
Hollyqui Nov 20, 2024
ee5351f
SN1-331: Adding API keys
Hollyqui Nov 22, 2024
4ebe74d
Adding test miner ids
Hollyqui Nov 22, 2024
20e5376
Adding tasks to scoring queue
Hollyqui Nov 24, 2024
37a8874
Enabling non-streaming response + bug fixes
Hollyqui Nov 25, 2024
a09cd9b
Making model loading non-blocking
Hollyqui Nov 25, 2024
566dc77
Protecting endpoints with API key
Hollyqui Nov 26, 2024
9d810ce
Improving error messages + improving API key saving
Hollyqui Nov 26, 2024
7497354
Merge pull request #461 from macrocosm-os/staging
bkb2135 Nov 26, 2024
685290a
Signing epistula properly for recipient
Hollyqui Nov 26, 2024
7296fe7
Passing task type
Hollyqui Nov 26, 2024
d497ee6
Merge pull request #465 from macrocosm-os/staging
bkb2135 Nov 26, 2024
622427a
Merge branch 'main' into kalei/api-working-branch
bkb2135 Nov 26, 2024
120a90a
Move streaming of miners into query_miners function
bkb2135 Nov 26, 2024
6cf7aa5
Use query_miners in api
bkb2135 Nov 27, 2024
07620fd
Fix syntax errors
bkb2135 Nov 27, 2024
21c2236
Manually dump models
bkb2135 Nov 27, 2024
ba900a2
Use autoawq 0.2.0
richwardle Nov 27, 2024
26d1db1
Support delta or message in sn19 response
richwardle Nov 27, 2024
0389d1b
Remove Unecessary Line
richwardle Nov 27, 2024
351c14c
Formatting
bkb2135 Nov 27, 2024
b132d13
Merge pull request #467 from macrocosm-os/hotfix/support-multiple-sn1…
bkb2135 Nov 27, 2024
0f6bfd7
Update pyproject.toml
bkb2135 Nov 27, 2024
fd186a9
Merge remote-tracking branch 'origin/release/v2.13.2' into kalei/api-…
bkb2135 Nov 27, 2024
9a037cc
Add test_api to scripts
bkb2135 Nov 28, 2024
0b37518
Update api_keys.json
bkb2135 Nov 28, 2024
9485e56
Update prompting/api/gpt_endpoints/api.py
bkb2135 Nov 28, 2024
1bf3996
Push Working Changes
richwardle Nov 28, 2024
6bab37e
Add Optional Api Deployment
bkb2135 Nov 28, 2024
bb115cf
Fixing formatting
Hollyqui Dec 2, 2024
09e4103
sort: fix import formatting
richwardle Dec 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,5 @@ core
app.config.js
wandb
.vscode
api_keys.json
prompting/api/api_keys.json
1 change: 1 addition & 0 deletions api_keys.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
16 changes: 10 additions & 6 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from loguru import logger

from prompting import mutable_globals
from prompting.api.api import start_api
from prompting.base.dendrite import DendriteResponseEvent
from prompting.base.epistula import query_miners
from prompting.base.forward import log_stream_results
Expand Down Expand Up @@ -92,7 +93,7 @@ async def run_step(self, k: int, timeout: float) -> ValidatorLoggingEvent | Erro
if response_event is None:
logger.warning("No response event collected. This should not be happening.")
return
logger.debug(f"Collected responses in {timer.elapsed_time:.2f} seconds")
logger.debug(f"Collected responses in {timer.final_time:.2f} seconds")

# scoring_manager will score the responses as and when the correct model is loaded
task_scorer.add_to_queue(
Expand All @@ -108,7 +109,7 @@ async def run_step(self, k: int, timeout: float) -> ValidatorLoggingEvent | Erro
return ValidatorLoggingEvent(
block=self.estimate_block,
step=self.step,
step_time=timer.elapsed_time,
step_time=timer.final_time,
response_event=response_event,
task_id=task.task_id,
)
Expand Down Expand Up @@ -163,7 +164,7 @@ async def forward(self):
if not event:
return

event.forward_time = timer.elapsed_time
event.forward_time = timer.final_time

def __enter__(self):
if settings.NO_BACKGROUND_THREAD:
Expand Down Expand Up @@ -196,6 +197,9 @@ def __exit__(self, exc_type, exc_value, traceback):


async def main():
if settings.DEPLOY_API:
asyncio.create_task(start_api())

GPUInfo.log_gpu_info()
# start profiling
asyncio.create_task(profiler.print_stats())
Expand All @@ -214,9 +218,9 @@ async def main():

# start scoring tasks in separate loop
asyncio.create_task(task_scorer.start())
# TODO: Think about whether we want to store the task queue locally in case of a crash
# TODO: Possibly run task scorer & model scheduler with a lock so I don't unload a model whilst it's generating
# TODO: Make weight setting happen as specific intervals as we load/unload models
# # TODO: Think about whether we want to store the task queue locally in case of a crash
# # TODO: Possibly run task scorer & model scheduler with a lock so I don't unload a model whilst it's generating
# # TODO: Make weight setting happen as specific intervals as we load/unload models
with Validator() as v:
while True:
logger.info(
Expand Down
340 changes: 184 additions & 156 deletions poetry.lock

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions prompting/api/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# This ensures uvicorn is imported first
import uvicorn
from fastapi import FastAPI
from loguru import logger

# Now we can safely import the rest
from prompting.api.api_managements.api import router as api_management_router
from prompting.api.gpt_endpoints.api import router as gpt_router
from prompting.api.miner_availabilities.api import router as miner_availabilities_router

app = FastAPI()

# Add routers at the application level
app.include_router(api_management_router, prefix="/api_management", tags=["api_management"])
app.include_router(miner_availabilities_router, prefix="/miner_availabilities", tags=["miner_availabilities"])
app.include_router(gpt_router, tags=["gpt"])


@app.get("/health")
def health():
logger.info("Health endpoint accessed.")
return {"status": "healthy"}


# if __name__ == "__main__":
async def start_api():
logger.info("Starting API...")
uvicorn.run("prompting.api.api:app", host="0.0.0.0", port=8004, loop="asyncio", reload=False)
1 change: 1 addition & 0 deletions prompting/api/api_keys.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
82 changes: 82 additions & 0 deletions prompting/api/api_managements/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import json
import secrets

from fastapi import APIRouter, Depends, Header, HTTPException
from loguru import logger

from prompting.settings import settings

router = APIRouter()


# Load and save functions for API keys
def load_api_keys():
try:
with open(settings.API_KEYS_FILE, "r") as f:
return json.load(f)
except FileNotFoundError:
return {}


def save_api_keys(api_keys):
with open(settings.API_KEYS_FILE, "w") as f:
json.dump(api_keys, f)


# Use lifespan to initialize API keys
_keys = load_api_keys()
logger.info(f"Loaded API keys: {_keys}")
save_api_keys(_keys)


# Dependency to validate the admin key
def validate_admin_key(admin_key: str = Header(...)):
if admin_key != settings.ADMIN_KEY:
raise HTTPException(status_code=403, detail="Invalid admin key")


# Dependency to validate API keys
def validate_api_key(api_key: str = Header(...)):
if api_key not in _keys:
raise HTTPException(status_code=403, detail="Invalid API key")
return _keys[api_key]


@router.post("/create-api-key/")
def create_api_key(rate_limit: int, admin_key: str = Depends(validate_admin_key)):
"""Creates a new API key with a specified rate limit."""
new_api_key = secrets.token_hex(16)
_keys[new_api_key] = {"rate_limit": rate_limit, "usage": 0}
save_api_keys(_keys)
return {"message": "API key created", "api_key": new_api_key}


@router.put("/modify-api-key/{api_key}")
def modify_api_key(api_key: str, rate_limit: int, admin_key: str = Depends(validate_admin_key)):
"""Modifies the rate limit of an existing API key."""
if api_key not in _keys:
raise HTTPException(status_code=404, detail="API key not found")
_keys[api_key]["rate_limit"] = rate_limit
save_api_keys(_keys)
return {"message": "API key updated", "api_key": api_key}


@router.delete("/delete-api-key/{api_key}")
def delete_api_key(api_key: str, admin_key: str = Depends(validate_admin_key)):
"""Deletes an existing API key."""
if api_key not in _keys:
raise HTTPException(status_code=404, detail="API key not found")
del _keys[api_key]
save_api_keys(_keys)
return {"message": "API key deleted"}


@router.get("/demo-endpoint/")
def demo_endpoint(api_key_data: dict = Depends(validate_api_key)):
"""A demo endpoint that requires a valid API key."""
return {"message": "Access granted", "your_rate_limit": api_key_data["rate_limit"]}


# # Create FastAPI app and include the router
# app = FastAPI()
# app.include_router(router)
117 changes: 117 additions & 0 deletions prompting/api/gpt_endpoints/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import json
import random
from typing import AsyncGenerator

from fastapi import APIRouter, Depends, HTTPException, Request
from loguru import logger

from prompting.api.api_managements.api import validate_api_key
from prompting.base.dendrite import DendriteResponseEvent, SynapseStreamResult
from prompting.base.epistula import query_miners
from prompting.miner_availability.miner_availability import miner_availabilities
from prompting.rewards.scoring import task_scorer
from prompting.settings import settings
from prompting.tasks.inference import InferenceTask
from prompting.tasks.task_registry import TaskRegistry
from prompting.utils.timer import Timer

router = APIRouter()


async def process_and_collect_stream(miner_id: int, request: dict, response: AsyncGenerator):
collected_content = []
collected_chunks_timings = []
with Timer() as timer:
async for chunk in response:
logger.debug(f"Chunk: {chunk}")
if hasattr(chunk, "choices") and chunk.choices and isinstance(chunk.choices[0].delta.content, str):
collected_content.append(chunk.choices[0].delta.content)
collected_chunks_timings.append(timer.elapsed_time())
yield f"data: {json.dumps(chunk.model_dump())}\n\n"

task = InferenceTask(
query=request["messages"][-1]["content"],
messages=[message["content"] for message in request["messages"]],
model=request.get("model"),
seed=request.get("seed"),
response="".join(collected_content),
)
logger.debug(f"Adding Organic Request to scoring queue: {task}")
response_event = DendriteResponseEvent(
stream_results=[
SynapseStreamResult(
uid=miner_id,
accumulated_chunks=collected_content,
accumulated_chunks_timings=collected_chunks_timings,
)
],
uids=[miner_id],
timeout=settings.NEURON_TIMEOUT,
completions=["".join(collected_content)],
)

task_scorer.add_to_queue(
task=task, response=response_event, dataset_entry=task.dataset_entry, block=-1, step=-1, task_id=task.task_id
)
yield "data: [DONE]\n\n"


@router.post("/mixture_of_agents")
async def mixture_of_agents(request: Request, api_key_data: dict = Depends(validate_api_key)):
return {"message": "Mixture of Agents"}


@router.post("/v1/chat/completions")
async def proxy_chat_completions(request: Request, api_key_data: dict = Depends(validate_api_key)):
body = await request.json()
task = TaskRegistry.get_task_by_name(body.get("task"))
if body.get("task") and not task:
raise HTTPException(status_code=400, detail=f"Task {body.get('task')} not found")
logger.debug(f"Requested Task: {body.get('task')}, {task}")

stream = body.get("stream")
body = {k: v for k, v in body.items() if k not in ["task", "stream"]}
body["task"] = task.__class__.__name__
body["seed"] = body.get("seed") or str(random.randint(0, 1_000_000))
logger.debug(f"Seed provided by miner: {bool(body.get('seed'))} -- Using seed: {body.get('seed')}")

if settings.TEST_MINER_IDS:
available_miners = settings.TEST_MINER_IDS
elif not settings.mode == "mock" and not (
available_miners := miner_availabilities.get_available_miners(task=task, model=body.get("model"))
):
raise HTTPException(
status_code=503,
detail=f"No miners available for model: {body.get('model')} and task: {task.__name__}",
)

response = query_miners(available_miners, json.dumps(body).encode("utf-8"), stream=stream)
if stream:
return response
else:
response = await response
response_event = DendriteResponseEvent(
stream_results=response,
uids=available_miners,
timeout=settings.NEURON_TIMEOUT,
completions=["".join(res.accumulated_chunks) for res in response],
)

task = task(
query=body["messages"][-1]["content"],
messages=[message["content"] for message in body["messages"]],
model=body.get("model"),
seed=body.get("seed"),
response=response_event,
)

task_scorer.add_to_queue(
task=task,
response=response_event,
dataset_entry=task.dataset_entry,
block=-1,
step=-1,
task_id=task.task_id,
)

return [res.model_dump() for res in response]
Empty file.
29 changes: 29 additions & 0 deletions prompting/api/miner_availabilities/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Literal

from fastapi import APIRouter
from loguru import logger

from prompting.miner_availability.miner_availability import miner_availabilities
from prompting.tasks.task_registry import TaskRegistry

router = APIRouter()


@router.post("/miner_availabilities")
async def get_miner_availabilities(uids: list[int] | None = None):
if uids:
return {uid: miner_availabilities.miners.get(uid) for uid in uids}
logger.info(f"Returning all miner availabilities for {len(miner_availabilities.miners)} miners")
return miner_availabilities.miners


@router.get("/get_available_miners")
async def get_available_miners(
task: Literal[tuple([config.task.__name__ for config in TaskRegistry.task_configs])] | None = None,
model: str | None = None,
k: int = 10,
):
logger.info(f"Getting {k} available miners for task {task} and model {model}")
task_configs = [config for config in TaskRegistry.task_configs if config.task.__name__ == task]
task_config = task_configs[0] if task_configs else None
return miner_availabilities.get_available_miners(task=task_config, model=model, k=k)
Loading