Hotfix/change-transformers-version #515

Merged · 8 commits · Dec 23, 2024
8 changes: 5 additions & 3 deletions .env.api.example
@@ -1,3 +1,5 @@
API_PORT = "8005"
API_HOST = "0.0.0.0"
# SCORING_KEY = "YOUR_SCORING_API_KEY_GOES_HERE"
API_PORT = "42170" # Port for the API server
API_HOST = "0.0.0.0" # Host for the API server
SCORING_KEY = "123" # The scoring key for the validator (must match the scoring key in the .env.validator file)
SCORE_ORGANICS = True # Whether to score organics
VALIDATOR_API = "0.0.0.0:8094" # The validator API to forward responses to for scoring
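For context, a minimal sketch of how these variables might be read on the API side, assuming python-dotenv and an `.env.api` file copied from this example; the repository's own `shared_settings` loader may differ, so the names and defaults below are assumptions, not the project's code:

```python
# Illustrative only: load the example variables with python-dotenv.
import os

from dotenv import load_dotenv

load_dotenv(".env.api")  # hypothetical copy of .env.api.example

API_PORT = int(os.getenv("API_PORT", "42170"))
API_HOST = os.getenv("API_HOST", "0.0.0.0")
SCORING_KEY = os.getenv("SCORING_KEY")  # must match the key in .env.validator
SCORE_ORGANICS = os.getenv("SCORE_ORGANICS", "False").lower() in ("true", "1")
VALIDATOR_API = os.getenv("VALIDATOR_API", "0.0.0.0:8094")
```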
1 change: 1 addition & 0 deletions .gitignore
@@ -183,3 +183,4 @@ wandb
.vscode
api_keys.json
prompting/api/api_keys.json
weights.csv
12 changes: 12 additions & 0 deletions api.config.js
@@ -0,0 +1,12 @@
module.exports = {
apps: [
{
name: 'api_server',
script: 'poetry',
interpreter: 'none',
args: ['run', 'python', 'validator_api/api.py'],
min_uptime: '5m',
max_restarts: 5
}
]
};
2 changes: 1 addition & 1 deletion api_keys.json
@@ -1 +1 @@
{}
{}
145 changes: 107 additions & 38 deletions neurons/validator.py
@@ -8,7 +8,8 @@
import multiprocessing as mp
import time

from loguru import logger
import loguru
import torch

from prompting.api.api import start_scoring_api
from prompting.llms.model_manager import model_scheduler
@@ -20,48 +21,116 @@
from prompting.weight_setting.weight_setter import weight_setter
from shared.profiling import profiler

# Add a handler to write logs to a file
loguru.logger.add("logfile.log", rotation="1000 MB", retention="10 days", level="DEBUG")
from loguru import logger

torch.multiprocessing.set_start_method("spawn", force=True)

NEURON_SAMPLE_SIZE = 100


def create_loop_process(task_queue, scoring_queue, reward_events):
async def spawn_loops(task_queue, scoring_queue, reward_events):
logger.info("Starting Profiler...")
asyncio.create_task(profiler.print_stats(), name="Profiler"),
logger.info("Starting ModelScheduler...")
asyncio.create_task(model_scheduler.start(scoring_queue), name="ModelScheduler"),
logger.info("Starting TaskScorer...")
asyncio.create_task(task_scorer.start(scoring_queue, reward_events), name="TaskScorer"),
logger.info("Starting WeightSetter...")
asyncio.create_task(weight_setter.start(reward_events))

# Main monitoring loop
start = time.time()

logger.info("Starting Main Monitoring Loop...")
while True:
await asyncio.sleep(5)
current_time = time.time()
time_diff = current_time - start
start = current_time

# Check if all tasks are still running
logger.debug(f"Running {time_diff:.2f} seconds")
logger.debug(f"Number of tasks in Task Queue: {len(task_queue)}")
logger.debug(f"Number of tasks in Scoring Queue: {len(scoring_queue)}")
logger.debug(f"Number of tasks in Reward Events: {len(reward_events)}")

asyncio.run(spawn_loops(task_queue, scoring_queue, reward_events))


def start_api():
async def start():
await start_scoring_api()
while True:
await asyncio.sleep(10)
logger.debug("Running API...")

asyncio.run(start())


def create_task_loop(task_queue, scoring_queue):
async def start(task_queue, scoring_queue):
logger.info("Starting AvailabilityCheckingLoop...")
asyncio.create_task(availability_checking_loop.start())

logger.info("Starting TaskSender...")
asyncio.create_task(task_sender.start(task_queue, scoring_queue))

logger.info("Starting TaskLoop...")
asyncio.create_task(task_loop.start(task_queue, scoring_queue))
while True:
await asyncio.sleep(10)
logger.debug("Running task loop...")

asyncio.run(start(task_queue, scoring_queue))


async def main():
# will start checking the availability of miners at regular intervals, needed for API and Validator
asyncio.create_task(availability_checking_loop.start())

if shared_settings.DEPLOY_SCORING_API:
# Use multiprocessing to bypass API blocking issue.
api_process = mp.Process(target=lambda: asyncio.run(start_scoring_api()))
api_process.start()

GPUInfo.log_gpu_info()
# start profiling
asyncio.create_task(profiler.print_stats())

# start rotating LLM models
asyncio.create_task(model_scheduler.start())

# start creating tasks
asyncio.create_task(task_loop.start())

# will start checking the availability of miners at regular intervals
asyncio.create_task(availability_checking_loop.start())

# start sending tasks to miners
asyncio.create_task(task_sender.start())

# sets weights at regular intervals (synchronised between all validators)
asyncio.create_task(weight_setter.start())

# start scoring tasks in separate loop
asyncio.create_task(task_scorer.start())
# # TODO: Think about whether we want to store the task queue locally in case of a crash
# # TODO: Possibly run task scorer & model scheduler with a lock so I don't unload a model whilst it's generating
# # TODO: Make weight setting happen as specific intervals as we load/unload models
start = time.time()
await asyncio.sleep(60)
while True:
await asyncio.sleep(5)
time_diff = -start + (start := time.time())
logger.debug(f"Running {time_diff:.2f} seconds")
with torch.multiprocessing.Manager() as manager:
reward_events = manager.list()
scoring_queue = manager.list()
task_queue = manager.list()

# Create process pool for managed processes
processes = []

try:
# # Start checking the availability of miners at regular intervals

if shared_settings.DEPLOY_SCORING_API:
# Use multiprocessing to bypass API blocking issue
api_process = mp.Process(target=start_api, name="API_Process")
api_process.start()
processes.append(api_process)

loop_process = mp.Process(
target=create_loop_process, args=(task_queue, scoring_queue, reward_events), name="LoopProcess"
)
task_loop_process = mp.Process(
target=create_task_loop, args=(task_queue, scoring_queue), name="TaskLoopProcess"
)
loop_process.start()
task_loop_process.start()
processes.append(loop_process)
processes.append(task_loop_process)
GPUInfo.log_gpu_info()

while True:
await asyncio.sleep(10)
logger.debug("Running...")

except Exception as e:
logger.error(f"Main loop error: {e}")
raise
finally:
# Clean up processes
for process in processes:
if process.is_alive():
process.terminate()
process.join()


# The main function parses the configuration and runs the validator.
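The refactor above moves the validator's loops into separate processes that share state through manager lists. A minimal sketch of that pattern, using the standard-library `multiprocessing` module rather than `torch.multiprocessing` as the PR does; the names here are illustrative, not the repository's:

```python
import asyncio
import multiprocessing as mp


def worker(task_queue):
    # Each child process runs its own asyncio loop over the shared, proxied list.
    async def loop():
        for i in range(3):
            task_queue.append(f"task-{i}")
            await asyncio.sleep(0.1)

    asyncio.run(loop())


if __name__ == "__main__":
    with mp.Manager() as manager:
        task_queue = manager.list()  # proxy object usable from any process
        process = mp.Process(target=worker, args=(task_queue,), name="WorkerProcess")
        process.start()
        try:
            process.join()
            # The parent can observe what the child appended, as main() does in its monitoring loop.
            print(f"Number of tasks in shared queue: {len(task_queue)}")
        finally:
            if process.is_alive():
                process.terminate()
```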
64 changes: 32 additions & 32 deletions poetry.lock

(Generated file; diff not rendered.)

2 changes: 1 addition & 1 deletion prompting/api/api.py
@@ -18,7 +18,7 @@ def health():


async def start_scoring_api():
logger.info("Starting API...")
logger.info(f"Starting Scoring API on https://0.0.0.0:{shared_settings.SCORING_API_PORT}")
uvicorn.run(
"prompting.api.api:app", host="0.0.0.0", port=shared_settings.SCORING_API_PORT, loop="asyncio", reload=False
)
21 changes: 19 additions & 2 deletions prompting/api/scoring/api.py
@@ -1,7 +1,8 @@
import uuid
from typing import Any

from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from loguru import logger

from prompting.llms.model_zoo import ModelZoo
from prompting.rewards.scoring import task_scorer
@@ -14,10 +15,25 @@
router = APIRouter()


def validate_scoring_key(api_key: str = Header(...)):
if api_key != shared_settings.SCORING_KEY:
raise HTTPException(status_code=403, detail="Invalid API key")


@router.post("/scoring")
async def score_response(request: Request): # , api_key_data: dict = Depends(validate_api_key)):
async def score_response(request: Request, api_key_data: dict = Depends(validate_scoring_key)):
model = None
payload: dict[str, Any] = await request.json()
body = payload.get("body")

try:
if body.get("model") is not None:
model = ModelZoo.get_model_by_id(body.get("model"))
except Exception:
logger.warning(
f"Organic request with model {body.get('model')} made but the model cannot be found in model zoo. Skipping scoring."
)
return
uid = int(payload.get("uid"))
chunks = payload.get("chunks")
llm_model = ModelZoo.get_model_by_id(model) if (model := body.get("model")) else None
@@ -39,3 +55,4 @@ async def score_response(request: Request): # , api_key_data: dict = Depends(va
step=-1,
task_id=str(uuid.uuid4()),
)
logger.info("Organic tas appended to scoring queue")
13 changes: 8 additions & 5 deletions prompting/llms/hf_llm.py
@@ -15,8 +15,8 @@ def __init__(self, model_id="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
Initialize Hugging Face model with reproducible settings and optimizations
"""
# Create a random seed for reproducibility
self.seed = random.randint(0, 1_000_000)
self.set_random_seeds(self.seed)
# self.seed = random.randint(0, 1_000_000)
# self.set_random_seeds(self.seed)
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
@@ -65,9 +65,12 @@ def generate(self, messages: list[str] | list[dict], sampling_params=None, seed=
)[0]

logger.debug(
f"PROMPT: {messages}\n\nRESPONSES: {results}\n\n"
f"SAMPLING PARAMS: {params}\n\n"
f"TIME FOR RESPONSE: {timer.elapsed_time}"
f"""REPRODUCIBLEHF WAS QUERIED:
PROMPT: {messages}\n\n
RESPONSES: {results}\n\n
SAMPLING PARAMS: {params}\n\n
SEED: {seed}\n\n
TIME FOR RESPONSE: {timer.elapsed_time}"""
)

return results if len(results) > 1 else results[0]