Skip to content

Commit

Permalink
fix: validator event loop (#130)
Browse files Browse the repository at this point in the history
* chore: fastapi service naming

* test: testing subscribe to block headers

* fix: get latest git tag

* fix: remove async code in block property

* fix: lazy get git tag

* chore: add warning

* chore: move test script

* perf: add exponential retries for subscriber

* chore: move script

* chore: remove unused variables

---------

Co-authored-by: karootplx <karoo@tensorplex.ai>
  • Loading branch information
jarvis8x7b and karootplx authored Feb 15, 2025
1 parent f10b867 commit 53082e7
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 41 deletions.
4 changes: 2 additions & 2 deletions auto_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bittensor.utils.btlogging import logging as logger

from commons.utils import datetime_to_iso8601_str
from dojo import __version__
from dojo import get_latest_git_tag
from dojo.utils.config import source_dotenv

source_dotenv()
Expand Down Expand Up @@ -214,7 +214,7 @@ def restart_docker(service_name):


def get_current_version():
version = __version__
version = get_latest_git_tag()
logger.debug(f"Current version: {version}")
return version

Expand Down
195 changes: 195 additions & 0 deletions commons/block_subscriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import asyncio
import time
from datetime import datetime
from typing import Any, Awaitable, Callable

from bittensor.core.async_subtensor import AsyncSubstrateInterface
from bittensor.core.subtensor import SubstrateRequestException
from loguru import logger

from commons.objects import ObjectManager

# Interval in seconds used as the watchdog polling period and, doubled, as the
# default maximum gap between blocks (presumably the chain's target block
# time -- confirm against the network parameters).
BLOCK_TIME = 12


class SubscriptionWatchdog:
    """Tracks when the last block arrived and decides whether the subscription is stale.

    Health is judged against a monotonic clock, so wall-clock jumps (NTP
    corrections, daylight-saving shifts of the naive local time) can neither
    trip the watchdog spuriously nor hide a stalled subscription.
    ``last_block_time`` is still maintained as a wall-clock ``datetime`` for
    code that reads it, e.g. to report how long ago the last block was seen.
    """

    def __init__(self, max_block_interval: float):
        """Start out healthy, as if a block had just been received.

        Args:
            max_block_interval: Longest gap between blocks, in seconds, that
                still counts as healthy.
        """
        self.max_block_interval = max_block_interval
        self.is_healthy = True
        self.last_block_time = datetime.now()
        # Monotonic twin of last_block_time; this is what health checks use.
        self._last_block_monotonic = time.monotonic()

    def update(self) -> None:
        """Updates the last block time and marks the subscription as healthy."""
        self.last_block_time = datetime.now()
        self._last_block_monotonic = time.monotonic()
        self.is_healthy = True

    def check_health(self) -> bool:
        """Refresh and return ``is_healthy``.

        Returns:
            True when the time elapsed since the last ``update()`` (or since
            construction) does not exceed ``max_block_interval`` seconds.
        """
        elapsed = time.monotonic() - self._last_block_monotonic
        self.is_healthy = elapsed <= self.max_block_interval
        return self.is_healthy


async def monitor_subscription(
    watchdog: SubscriptionWatchdog,
    max_block_interval: float,
    check_interval: float = BLOCK_TIME,
):
    """Poll the watchdog until it reports a stalled block subscription.

    Runs until it fails: every ``check_interval`` seconds it asks the watchdog
    whether a block has been received recently enough. The coroutine never
    returns normally; it either raises or is cancelled by its owner.

    Args:
        watchdog: Watchdog whose ``update()`` is called whenever a block arrives.
        max_block_interval: Allowed gap between blocks in seconds. Only used
            in the log message here; the health decision itself is made by
            ``watchdog.check_health()``.
        check_interval: Seconds to sleep between health checks. Defaults to
            ``BLOCK_TIME``.

    Raises:
        ConnectionError: When the watchdog reports that no block has been
            received within its allowed interval, indicating the subscription
            has likely failed.
    """
    while True:
        await asyncio.sleep(check_interval)
        # Wall-clock age of the last block, used for the log messages only.
        time_since_last = (datetime.now() - watchdog.last_block_time).total_seconds()

        if watchdog.check_health():
            logger.debug(
                f"Subscription is healthy - last block {time_since_last:.1f} seconds ago"
            )
            continue

        logger.warning(
            f"No blocks received for {time_since_last:.1f} seconds! (max allowed: {max_block_interval})"
        )
        raise ConnectionError(
            f"Subscription watchdog timeout - no blocks for {time_since_last:.1f} seconds"
        )


async def start_block_subscriber(
    callbacks: list[Callable[..., Awaitable[Any]]],
    url: str | None = None,
    retry_delay: float = 5.0,
    max_block_interval: float = 2 * BLOCK_TIME,
    max_retries: int | None = None,
    max_retry_delay: float = 300.0,
):
    """Subscribe to finalized block headers and keep the subscription alive.

    Connects to the substrate node, forwards every received header to each
    callback and runs a watchdog next to the subscription. When connecting
    fails, the subscription errors out or ends, or the watchdog sees no block
    for too long, the subscriber reconnects with exponential backoff.

    Args:
        callbacks: Coroutine functions awaited in order for every block
            header, with whatever arguments the subscription passes to its
            handler. A callback that raises is logged and does not prevent
            the remaining callbacks from running.
        url: Endpoint of the substrate node. Defaults to
            ``ObjectManager.get_config().subtensor.chain_endpoint``, looked up
            when the function is called rather than at import time.
        retry_delay: Delay in seconds before the first reconnect attempt; it
            doubles with every further consecutive failure.
        max_block_interval: Maximum gap between blocks, in seconds, before
            the watchdog declares the subscription dead. Defaults to
            2*BLOCK_TIME.
        max_retries: Number of consecutive failures after which the last
            error is re-raised. ``None`` retries forever. The count restarts
            whenever a block is received.
        max_retry_delay: Upper bound in seconds for the backoff delay.

    Raises:
        Exception: The most recent connection/subscription/watchdog error,
            once ``max_retries`` consecutive attempts have failed.
        KeyboardInterrupt: Re-raised after logging when interrupted by the user.
    """
    if url is None:
        url = ObjectManager.get_config().subtensor.chain_endpoint  # type: ignore

    watchdog = SubscriptionWatchdog(max_block_interval)

    retry_count = 0

    async def wrapped_callback(*args, **kwargs):
        """Record the block for the watchdog, then fan it out to all callbacks.

        Resets the retry counter because a received block proves the
        subscription works. Returns None so the subscription keeps running.
        """
        nonlocal retry_count
        retry_count = 0
        watchdog.update()

        for callback in callbacks:
            try:
                await callback(*args, **kwargs)
            except Exception as callback_error:
                # A broken consumer must not tear down the shared subscription.
                logger.error(f"Block callback {callback!r} failed: {callback_error}")

    while True:
        try:
            # Connect to the substrate node
            async with AsyncSubstrateInterface(url=url) as substrate:
                # Give every fresh connection a full interval; otherwise time
                # spent in backoff would count against the new subscription.
                watchdog.update()
                monitor_task = asyncio.create_task(
                    monitor_subscription(watchdog, max_block_interval)
                )
                subscription_task = None
                try:
                    logger.info("Subscribing to block headers...")
                    subscription_task = asyncio.create_task(
                        substrate.subscribe_block_headers(
                            subscription_handler=wrapped_callback,
                            finalized_only=True,
                        )
                    )

                    # Wait for either task to complete (or fail)
                    done, _ = await asyncio.wait(
                        [monitor_task, subscription_task],
                        return_when=asyncio.FIRST_COMPLETED,
                    )

                    if monitor_task in done:
                        try:
                            monitor_task.result()
                        except ConnectionError:
                            logger.error("Watchdog detected subscription failure")
                            raise

                    if subscription_task in done:
                        # Retrieve the outcome so a failure reaches the retry
                        # logic instead of being dropped with the task.
                        try:
                            subscription_task.result()
                        except Exception as subscription_error:
                            logger.error(f"Subscription failed: {subscription_error}")
                            raise
                        # The handler never returns a value, so a subscription
                        # that ends on its own is treated as a failure too.
                        raise ConnectionError(
                            "Block header subscription ended unexpectedly"
                        )
                finally:
                    # Stop both tasks on every exit path, including cancellation
                    # of this coroutine, and collect their outcomes.
                    tasks = [
                        task
                        for task in (monitor_task, subscription_task)
                        if task is not None
                    ]
                    for task in tasks:
                        if not task.done():
                            task.cancel()
                    await asyncio.gather(*tasks, return_exceptions=True)

        except KeyboardInterrupt:
            logger.info("\nSubscription ended by user")
            raise

        except (SubstrateRequestException, Exception) as e:
            logger.error(f"Error occurred: {e}")

            retry_count += 1
            if max_retries is not None and retry_count >= max_retries:
                logger.error(
                    f"Max retries ({max_retries}) reached. Stopping subscription."
                )
                raise

            # Exponential backoff, capped so that long outages do not push the
            # next attempt hours away (the exponent cap avoids float overflow).
            exponent = min(retry_count - 1, 32)
            current_delay = min(retry_delay * (2**exponent), max_retry_delay)

            logger.info(
                f"Attempting to resubscribe in {current_delay} seconds... (attempt {retry_count})"
            )
            await asyncio.sleep(current_delay)


async def your_callback(block: dict):
    """Example block handler: report each received block via ``logger.success``."""
    message = f"Received block: {block}"
    logger.success(message)


async def main():
    """Run the block subscriber on its own with the example callback.

    Intended for manual testing of the subscription: errors that escape the
    subscriber are logged instead of propagating.
    """
    try:
        # The watchdog treats the subscription as failed when no block has
        # arrived within 12 seconds; failures are retried with exponential
        # backoff starting at 5 seconds. max_retries is left at its default,
        # so the subscriber keeps retrying instead of raising.
        await start_block_subscriber(
            [your_callback],
            max_block_interval=12,
            retry_delay=5.0,
        )
    except KeyboardInterrupt:
        logger.info("Shutting down...")
    except Exception as e:
        logger.error(f"Subscription failed: {e}")


if __name__ == "__main__":
    asyncio.run(main())
46 changes: 26 additions & 20 deletions dojo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import os
import subprocess

from git import Repo

from dojo.utils.config import get_config, source_dotenv

source_dotenv()


def get_latest_git_tag():
try:
# Get the latest git tag
latest_tag = (
subprocess.check_output(["git", "describe", "--tags", "--abbrev=0"])
.strip()
.decode("utf-8")
)
return latest_tag.lstrip("v")
except subprocess.CalledProcessError as e:
print(f"Error getting the latest Git tag: {e}")
raise RuntimeError("Failed to get latest Git tag")
def get_latest_git_tag(repo_path="."):
    """Return the name of the most recent tag, without leading "v" characters.

    Tags are ordered by the committed date of the commit they point to and the
    last one is used. Returns None when the repository at ``repo_path`` has no
    tags.
    """
    tags_by_commit_date = sorted(
        Repo(repo_path).tags, key=lambda tag: tag.commit.committed_date
    )
    if not tags_by_commit_date:
        return None
    newest_tag = tags_by_commit_date[-1]
    return str(newest_tag).lstrip("v")


def get_latest_remote_tag(repo_path="."):
    """Return the highest-versioned tag on ``origin`` in the same form as ``get_latest_git_tag``.

    Lists the remote tags sorted by descending version and takes the first
    entry. The name is normalised so it can be compared with the local tag:
    the ``^{}`` suffix that ``git ls-remote`` appends to peeled annotated-tag
    entries is dropped and leading "v" characters are stripped.

    Args:
        repo_path: Path of the local repository whose ``origin`` remote is queried.

    Returns:
        The normalised tag name, or None when the remote lists no tags.
    """
    repo = Repo(repo_path)
    listing = repo.git.ls_remote("--tags", "--sort=-v:refname", "origin")
    for line in listing.splitlines():
        _, marker, tag_name = line.partition("refs/tags/")
        if not marker or not tag_name:
            continue
        if tag_name.endswith("^{}"):
            tag_name = tag_name[: -len("^{}")]
        return tag_name.lstrip("v")
    return None


def get_commit_hash():
Expand All @@ -34,14 +38,16 @@ def get_commit_hash():
raise RuntimeError("Failed to get latest Git commit hash")


# Define the version of the template module.
__version__ = get_latest_git_tag()
version_split = __version__.split(".")
__spec_version__ = (
(1000 * int(version_split[0]))
+ (10 * int(version_split[1]))
+ (1 * int(version_split[2]))
)
def get_spec_version():
    """Encode the latest git tag ``major.minor.patch`` as one integer.

    The result is ``1000 * major + 10 * minor + patch``.

    Raises:
        ValueError: If no git tag is found.
    """
    tag = get_latest_git_tag()
    if tag is None:
        raise ValueError("No Git tag found")

    parts = tag.split(".")
    major, minor, patch = (int(parts[index]) for index in range(3))
    return 1000 * major + 10 * minor + patch


VALIDATOR_MIN_STAKE = int(os.getenv("VALIDATOR_MIN_STAKE", "20000"))
Expand Down
4 changes: 2 additions & 2 deletions dojo/base/neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from commons.objects import ObjectManager
from commons.utils import initialise, ttl_get_block
from dojo import __spec_version__ as spec_version
from dojo import get_spec_version


class BaseNeuron(ABC):
Expand All @@ -18,7 +18,7 @@ class BaseNeuron(ABC):
subtensor: bt.subtensor
wallet: bt.wallet
metagraph: bt.metagraph
spec_version: int = spec_version
spec_version: int = get_spec_version()

@property
def block(self):
Expand Down
2 changes: 1 addition & 1 deletion entrypoints/validator_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from commons.objects import ObjectManager
from dojo import VALIDATOR_MIN_STAKE

app = FastAPI(title="Dataset Upload Service")
app = FastAPI(title="Validator API Service")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand Down
4 changes: 4 additions & 0 deletions main_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi.middleware.cors import CORSMiddleware

from commons.api.middleware import LimitContentLengthMiddleware
from commons.block_subscriber import start_block_subscriber
from commons.dataset.synthetic import SyntheticAPI
from commons.objects import ObjectManager
from database.client import connect_db, disconnect_db
Expand Down Expand Up @@ -53,6 +54,9 @@ async def main():
asyncio.create_task(validator.run()),
asyncio.create_task(validator.update_score_and_send_feedback()),
asyncio.create_task(validator.send_heartbeats()),
asyncio.create_task(
start_block_subscriber(callbacks=[validator.block_headers_callback])
),
]

await server.serve()
Expand Down
33 changes: 17 additions & 16 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@
get_new_uuid,
initialise,
set_expire_time,
ttl_get_block,
)
from dojo import __spec_version__
from dojo import get_latest_git_tag, get_latest_remote_tag, get_spec_version
from dojo.protocol import (
CompletionResponse,
CriteriaType,
Expand All @@ -63,6 +62,15 @@
ObfuscatedModelMap: TypeAlias = Dict[str, str]


# Advisory version check. It runs at import time and get_latest_remote_tag()
# queries the "origin" remote, so a failure is logged instead of aborting
# startup.
try:
    latest_local = get_latest_git_tag()
    latest_remote = get_latest_remote_tag()
except Exception as version_check_error:
    latest_local = latest_remote = None
    logger.warning(
        f"Unable to compare local and remote versions: {version_check_error}"
    )
else:
    # get_latest_git_tag() strips the leading "v"; bring the remote tag into
    # the same form (and drop a peeled-tag "^{}" suffix) before comparing.
    _remote_version = latest_remote
    if _remote_version is not None:
        if _remote_version.endswith("^{}"):
            _remote_version = _remote_version[: -len("^{}")]
        _remote_version = _remote_version.lstrip("v")

    if latest_local != _remote_version:
        logger.warning(
            "Your repository is not up to date, and may fail to set weights."
        )
        logger.warning(
            f"latest local version: {latest_local}\nlatest remote version: {_remote_version}"
        )


class Validator:
_should_exit: bool = False
_scores_alock = asyncio.Lock()
Expand All @@ -74,7 +82,7 @@ class Validator:
subtensor: bt.subtensor
wallet: bt.wallet # type: ignore
metagraph: bt.metagraph
spec_version: int = __spec_version__
spec_version: int = get_spec_version()

def __init__(self):
self.MAX_BLOCK_CHECK_ATTEMPTS = 3
Expand Down Expand Up @@ -506,19 +514,7 @@ async def sync(self):

@property
def block(self):
try:
if not self.loop.run_until_complete(self._ensure_subtensor_connection()):
logger.warning(
"Subtensor connection failed - returning last known block"
)
return self._last_block if self._last_block is not None else 0

self._last_block = ttl_get_block(self.subtensor)
self._block_check_attempts = 0
return self._last_block
except Exception as e:
logger.error(f"Error getting block number: {e}")
return self._last_block if self._last_block is not None else 0
return self._last_block

async def _try_reconnect_subtensor(self):
self._block_check_attempts += 1
Expand Down Expand Up @@ -1420,3 +1416,8 @@ async def _get_dojo_task_scores_and_gt(
}
)
return hotkey_to_dojo_task_scores_and_gt

async def block_headers_callback(self, block: dict):
    """Cache the block number of a received block header in ``self._last_block``.

    Args:
        block: Payload delivered by the block subscription; the number is
            read from ``block["header"]["number"]``.

    A payload without a usable block number is logged and ignored, leaving
    the previously cached value untouched, so that a malformed message does
    not raise inside the subscription handler.
    """
    logger.debug(f"Received block headers{block}")
    raw_number = (block.get("header") or {}).get("number")
    try:
        block_number = int(raw_number)
    except (TypeError, ValueError):
        logger.warning(
            f"Ignoring block header without a valid block number: {raw_number!r}"
        )
        return
    self._last_block = block_number

0 comments on commit 53082e7

Please sign in to comment.