Skip to content

Commit

Permalink
fix: fixed linter issue, and commitizen
Browse files Browse the repository at this point in the history
  • Loading branch information
codebender37 authored and karootplx committed Dec 9, 2024
1 parent 23d085a commit 553ad66
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 37 deletions.
4 changes: 4 additions & 0 deletions commons/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ class ObjectManager:
def get_miner(cls):
if get_config().simulation:
from simulator.miner import MinerSim

if cls._miner is None:
cls._miner = MinerSim()
else:
from neurons.miner import Miner

if cls._miner is None:
cls._miner = Miner()
return cls._miner
Expand All @@ -23,10 +25,12 @@ def get_miner(cls):
def get_validator(cls):
if get_config().simulation:
from simulator.validator import ValidatorSim

if cls._validator is None:
cls._validator = ValidatorSim()
else:
from neurons.validator import Validator

if cls._validator is None:
cls._validator = Validator()
return cls._validator
Expand Down
2 changes: 1 addition & 1 deletion dojo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ def get_dojo_api_base_url() -> str:
if base_url is None:
raise ValueError("DOJO_API_BASE_URL is not set in the environment.")

return base_url
return base_url
2 changes: 1 addition & 1 deletion entrypoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ if [ "$1" = 'validator' ]; then
--neuron.type validator \
--wandb.project_name ${WANDB_PROJECT_NAME} \
${EXTRA_ARGS}
fi
fi
33 changes: 17 additions & 16 deletions simulator/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ def __init__(self):
host = os.getenv("REDIS_HOST", "localhost")
port = int(os.getenv("REDIS_PORT", 6379))
self.redis_client = redis.Redis(
host=host,
port=port,
db=0,
decode_responses=True
host=host, port=port, db=0, decode_responses=True
)
logger.info("Redis connection established")

Expand All @@ -42,12 +39,14 @@ def __init__(self):
def _configure_simulation(self):
"""Configure simulation parameters with environment variables or defaults."""
self.response_behaviors = {
'normal': float(os.getenv("SIM_NORMAL_RESP_PROB", 0.8)),
'no_response': float(os.getenv("SIM_NO_RESP_PROB", 0.1)),
'timeout': float(os.getenv("SIM_TIMEOUT_PROB", 0.1))
"normal": float(os.getenv("SIM_NORMAL_RESP_PROB", 0.8)),
"no_response": float(os.getenv("SIM_NO_RESP_PROB", 0.1)),
"timeout": float(os.getenv("SIM_TIMEOUT_PROB", 0.1)),
}

async def forward_feedback_request(self, synapse: FeedbackRequest) -> FeedbackRequest:
async def forward_feedback_request(
self, synapse: FeedbackRequest
) -> FeedbackRequest:
try:
# Validate that synapse, dendrite, dendrite.hotkey, and response are not None
if not synapse or not synapse.dendrite or not synapse.dendrite.hotkey:
Expand All @@ -69,7 +68,7 @@ async def forward_feedback_request(self, synapse: FeedbackRequest) -> FeedbackRe
self.redis_client.set(
redis_key,
new_synapse.model_dump_json(),
ex=86400 # expire after 24 hours
ex=86400, # expire after 24 hours
)
logger.info(f"Stored feedback request {synapse.request_id}")

Expand All @@ -81,7 +80,9 @@ async def forward_feedback_request(self, synapse: FeedbackRequest) -> FeedbackRe
traceback.print_exc()
return synapse

async def forward_task_result_request(self, synapse: TaskResultRequest) -> TaskResultRequest | None:
async def forward_task_result_request(
self, synapse: TaskResultRequest
) -> TaskResultRequest | None:
try:
logger.info(f"Received TaskResultRequest for task id: {synapse.task_id}")
if not synapse or not synapse.task_id:
Expand All @@ -91,9 +92,9 @@ async def forward_task_result_request(self, synapse: TaskResultRequest) -> TaskR
# Simulate different response behaviors
behavior = self._get_response_behavior()

if behavior in ['no_response', 'timeout']:
if behavior in ["no_response", "timeout"]:
logger.debug(f"Simulating {behavior} for task {synapse.task_id}")
if behavior == 'timeout':
if behavior == "timeout":
await asyncio.sleep(30)
return None

Expand All @@ -113,17 +114,17 @@ async def forward_task_result_request(self, synapse: TaskResultRequest) -> TaskR
for criteria_type in feedback_request.criteria_types:
result = Result(
type=criteria_type.type,
value=self._generate_scores(feedback_request.ground_truth)
value=self._generate_scores(feedback_request.ground_truth),
)

task_result = TaskResult(
id=get_new_uuid(),
status='COMPLETED',
status="COMPLETED",
created_at=current_time,
updated_at=current_time,
result_data=[result],
worker_id=get_new_uuid(),
task_id=synapse.task_id
task_id=synapse.task_id,
)
task_results.append(task_result)

Expand All @@ -144,7 +145,7 @@ def _get_response_behavior(self) -> str:
"""Determine the response behavior based on configured probabilities."""
return random.choices(
list(self.response_behaviors.keys()),
weights=list(self.response_behaviors.values())
weights=list(self.response_behaviors.values()),
)[0]

def _generate_scores(self, ground_truth: dict) -> dict:
Expand Down
42 changes: 23 additions & 19 deletions simulator/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ def block(self):
except BrokenPipeError:
self._block_check_attempts += 1
if self._block_check_attempts >= self.MAX_BLOCK_CHECK_ATTEMPTS:
logger.error("Multiple failed attempts to get block number, attempting reconnection")
if asyncio.get_event_loop().run_until_complete(self._try_reconnect_subtensor()):
logger.error(
"Multiple failed attempts to get block number, attempting reconnection"
)
if asyncio.get_event_loop().run_until_complete(
self._try_reconnect_subtensor()
):
return self.block

return self._last_block if self._last_block is not None else 0
Expand All @@ -60,8 +64,8 @@ def block(self):
def check_registered(self):
new_subtensor = bt.subtensor(self.subtensor.config)
if not new_subtensor.is_hotkey_registered(
netuid=self.config.netuid,
hotkey_ss58=self.wallet.hotkey.ss58_address,
netuid=self.config.netuid,
hotkey_ss58=self.wallet.hotkey.ss58_address,
):
logger.error(
f"Wallet: {self.wallet} is not registered on netuid {self.config.netuid}."
Expand All @@ -70,9 +74,9 @@ def check_registered(self):
exit()

async def send_request(
self,
synapse: FeedbackRequest | None = None,
external_user: bool = False,
self,
synapse: FeedbackRequest | None = None,
external_user: bool = False,
):
start = get_epoch_time()
# typically the request may come from an external source however,
Expand All @@ -89,7 +93,7 @@ async def send_request(
self.metagraph.axons[uid]
for uid in sel_miner_uids
if self.metagraph.axons[uid].hotkey.casefold()
!= self.wallet.hotkey.ss58_address.casefold()
!= self.wallet.hotkey.ss58_address.casefold()
]
if not len(axons):
logger.warning("🤷 No axons to query ... skipping")
Expand Down Expand Up @@ -135,7 +139,7 @@ async def send_request(
prompt=data.prompt,
completion_responses=data.responses,
expire_at=expire_at,
ground_truth=data.ground_truth # Added ground truth!!!!!
ground_truth=data.ground_truth, # Added ground truth!!!!!
)
elif external_user:
obfuscated_model_to_model = self.obfuscate_model_names(
Expand Down Expand Up @@ -217,24 +221,24 @@ async def send_request(

@staticmethod
async def _send_shuffled_requests(
dendrite: bt.dendrite, axons: List[bt.AxonInfo], synapse: FeedbackRequest
dendrite: bt.dendrite, axons: List[bt.AxonInfo], synapse: FeedbackRequest
) -> list[FeedbackRequest]:
"""Send the same request to all miners without shuffling the order.
WARNING: This should only be used for testing/debugging as it could allow miners to game the system.
WARNING: This should only be used for testing/debugging as it could allow miners to game the system.
Args:
dendrite (bt.dendrite): Communication channel to send requests
axons (List[bt.AxonInfo]): List of miner endpoints
synapse (FeedbackRequest): The feedback request to send
Args:
dendrite (bt.dendrite): Communication channel to send requests
axons (List[bt.AxonInfo]): List of miner endpoints
synapse (FeedbackRequest): The feedback request to send
Returns:
list[FeedbackRequest]: List of miner responses
"""
Returns:
list[FeedbackRequest]: List of miner responses
"""
all_responses = []
batch_size = 10

for i in range(0, len(axons), batch_size):
batch_axons = axons[i: i + batch_size]
batch_axons = axons[i : i + batch_size]
tasks = []

for axon in batch_axons:
Expand Down

0 comments on commit 553ad66

Please sign in to comment.