diff --git a/async_substrate_interface/substrate_interface.py b/async_substrate_interface/substrate_interface.py index 66876ff..4661a87 100644 --- a/async_substrate_interface/substrate_interface.py +++ b/async_substrate_interface/substrate_interface.py @@ -15,6 +15,7 @@ from collections.abc import Iterable from dataclasses import dataclass from datetime import datetime +from functools import partial from hashlib import blake2b from typing import ( Optional, @@ -41,7 +42,7 @@ ExtrinsicNotFound, BlockNotFound, ) -from async_substrate_interface.utils import execute_coroutine, hex_to_bytes +from async_substrate_interface.utils import hex_to_bytes, EventLoopManager from async_substrate_interface.utils.storage import StorageKey if TYPE_CHECKING: @@ -178,7 +179,7 @@ def __init__( block_hash: Optional[str] = None, block_number: Optional[int] = None, extrinsic_idx: Optional[int] = None, - finalized=None, + finalized: bool = False, ): """ Object containing information of submitted extrinsic. Block hash where extrinsic is included is required @@ -512,7 +513,8 @@ def __init__( block_hash: Optional[str] = None, block_number: Optional[int] = None, extrinsic_idx: Optional[int] = None, - finalized=None, + finalized: bool = False, + event_loop_mgr: EventLoopManager = None, ): self._async_instance = AsyncExtrinsicReceipt( substrate, @@ -522,7 +524,7 @@ def __init__( extrinsic_idx, finalized, ) - self.event_loop = asyncio.get_event_loop() + self.event_loop_mgr = event_loop_mgr or EventLoopManager() def __getattr__(self, name): attr = getattr(self._async_instance, name) @@ -530,12 +532,12 @@ def __getattr__(self, name): if asyncio.iscoroutinefunction(attr): def sync_method(*args, **kwargs): - return self.event_loop.run_until_complete(attr(*args, **kwargs)) + return self.event_loop_mgr.run(attr(*args, **kwargs)) return sync_method elif asyncio.iscoroutine(attr): # indicates this is an async_property - return self.event_loop.run_until_complete(attr) + return self.event_loop_mgr.run(attr) else: return attr @@ -554,6 +556,7 @@ def __init__( last_key: Optional[str] = None, max_results: Optional[int] = None, ignore_decoding_errors: bool = False, + event_loop_mgr: Optional[EventLoopManager] = None, ): self.records = records self.page_size = page_size @@ -567,6 +570,7 @@ def __init__( self.ignore_decoding_errors = ignore_decoding_errors self.loading_complete = False self._buffer = iter(self.records) # Initialize the buffer with initial records + self.event_loop_mgr = event_loop_mgr async def retrieve_next_page(self, start_key) -> list: result = await self.substrate.query_map( @@ -624,19 +628,17 @@ async def __anext__(self): def __next__(self): try: - return self.substrate.event_loop.run_until_complete(self.__anext__()) + return self.event_loop_mgr.run(self.__anext__()) except StopAsyncIteration: raise StopIteration + except AttributeError: + raise AttributeError( + "This item is an async iterator. You need to iterate over it with `async for`." + ) def __getitem__(self, item): return self.records[item] - def load_all(self): - async def _load_all(): - return [item async for item in self] - - return asyncio.get_event_loop().run_until_complete(_load_all()) - @dataclass class Preprocessed: @@ -1022,13 +1024,12 @@ def __init__( auto_discover: bool = True, ss58_format: Optional[int] = None, type_registry: Optional[dict] = None, - chain_name: Optional[str] = None, + chain_name: str = "", sync_calls: bool = False, + event_loop_mgr: Optional[EventLoopManager] = None, max_retries: int = 5, retry_timeout: float = 60.0, - event_loop: Optional[asyncio.BaseEventLoop] = None, _mock: bool = False, - pre_initialize: bool = True, ): """ The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to @@ -1045,10 +1046,8 @@ def __init__( sync_calls: whether this instance is going to be called through a sync wrapper or plain max_retries: number of times to retry RPC requests before giving up retry_timeout: how to long wait since the last ping to retry the RPC request - event_loop: the event loop to use + event_loop_mgr: an EventLoopManager instance, only used in the case where `sync_calls` is `True` _mock: whether to use mock version of the subtensor interface - pre_initialize: whether to initialise the network connections at initialisation of the - AsyncSubstrateInterface object """ self.max_retries = max_retries @@ -1080,16 +1079,16 @@ def __init__( ) self.__metadata_cache = {} self.metadata_version_hex = "0x0f000000" # v15 - self.event_loop = event_loop or asyncio.get_event_loop() - self.sync_calls = sync_calls - self.extrinsic_receipt_cls = ( - AsyncExtrinsicReceipt if self.sync_calls is False else ExtrinsicReceipt - ) - if pre_initialize: - if not _mock: - self.event_loop.create_task(self.initialize()) - else: - self.reload_type_registry() + if sync_calls is True: + self.query_map_result_cls = partial( + QueryMapResult, event_loop_mgr=event_loop_mgr + ) + self.extrinsic_receipt_cls = partial( + ExtrinsicReceipt, event_loop_mgr=event_loop_mgr + ) + else: + self.query_map_result_cls = QueryMapResult + self.extrinsic_receipt_cls = AsyncExtrinsicReceipt async def __aenter__(self): await self.initialize() @@ -3773,7 +3772,7 @@ def concat_hash_len(key_hasher: str) -> int: raise item_value = None result.append([item_key, item_value]) - return QueryMapResult( + return self.query_map_result_cls( records=result, page_size=page_size, module=module, @@ -3999,12 +3998,12 @@ async def _handler(block_data: dict[str, Any]): class SyncWebsocket: - def __init__(self, websocket: "Websocket", event_loop: asyncio.AbstractEventLoop): + def __init__(self, websocket: "Websocket", event_loop_manager: EventLoopManager): self._ws = websocket - self._event_loop = event_loop + self._event_loop_mgr = event_loop_manager def close(self): - execute_coroutine(self._ws.shutdown(), event_loop=self._event_loop) + self._event_loop_mgr.run(self._ws.shutdown()) class SubstrateInterface: @@ -4013,7 +4012,7 @@ class SubstrateInterface: """ url: str - event_loop: asyncio.AbstractEventLoop + event_loop_mgr: EventLoopManager websocket: "SyncWebsocket" def __init__( @@ -4024,12 +4023,12 @@ def __init__( ss58_format: Optional[int] = None, type_registry: Optional[dict] = None, chain_name: Optional[str] = None, - event_loop: Optional[asyncio.AbstractEventLoop] = None, + event_loop_manager: Optional[EventLoopManager] = None, _mock: bool = False, substrate: Optional["AsyncSubstrateInterface"] = None, ): - event_loop = substrate.event_loop if substrate else event_loop self.url = url + self.event_loop_mgr = event_loop_manager or EventLoopManager() self._async_instance = ( AsyncSubstrateInterface( url=url, @@ -4037,16 +4036,15 @@ def __init__( auto_discover=auto_discover, ss58_format=ss58_format, type_registry=type_registry, + event_loop_mgr=self.event_loop_mgr, chain_name=chain_name, - sync_calls=True, - event_loop=event_loop, _mock=_mock, ) if not substrate else substrate ) - self.event_loop = event_loop or asyncio.get_event_loop() - self.websocket = SyncWebsocket(self._async_instance.ws, self.event_loop) + self.event_loop_mgr.run(self._async_instance.initialize()) + self.websocket = SyncWebsocket(self._async_instance.ws, self.event_loop_mgr) @property def last_block_hash(self): @@ -4057,10 +4055,10 @@ def metadata(self): return self._async_instance.metadata def __del__(self): - execute_coroutine(self._async_instance.close()) + self.event_loop_mgr.run(self._async_instance.close()) def _run(self, coroutine): - return execute_coroutine(coroutine, self.event_loop) + return self.event_loop_mgr.run(coroutine) def __getattr__(self, name): attr = getattr(self._async_instance, name) @@ -4258,3 +4256,34 @@ def create_storage_key( pallet, storage_function, params, block_hash ) ) + + +async def get_async_substrate_interface( + url: str, + use_remote_preset: bool = False, + auto_discover: bool = True, + ss58_format: Optional[int] = None, + type_registry: Optional[dict] = None, + chain_name: Optional[str] = None, + sync_calls: bool = False, + max_retries: int = 5, + retry_timeout: float = 60.0, + _mock: bool = False, +) -> "AsyncSubstrateInterface": + """ + Factory function for creating an initialized AsyncSubstrateInterface + """ + substrate = AsyncSubstrateInterface( + url, + use_remote_preset, + auto_discover, + ss58_format, + type_registry, + chain_name, + sync_calls, + max_retries, + retry_timeout, + _mock, + ) + await substrate.initialize() + return substrate diff --git a/async_substrate_interface/utils/__init__.py b/async_substrate_interface/utils/__init__.py index 48ef568..4fe53e1 100644 --- a/async_substrate_interface/utils/__init__.py +++ b/async_substrate_interface/utils/__init__.py @@ -1,9 +1,46 @@ import asyncio -from typing import Optional, TYPE_CHECKING +import threading +from typing import Optional -if TYPE_CHECKING: - from typing import Coroutine +class EventLoopManager: + """Singleton class to manage a living asyncio event loop.""" + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._init_event_loop() + return cls._instance + + def _init_event_loop(self): + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=self._start_loop, daemon=True) + self.thread.start() + + def _start_loop(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + def run(self, coroutine): + while self.loop is None: + pass + future = asyncio.run_coroutine_threadsafe(coroutine, self.loop) + return future.result() # Blocks until coroutine completes + + def stop(self): + """Stop the event loop.""" + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join() + + @classmethod + def get_event_loop(cls) -> asyncio.AbstractEventLoop: + return cls().loop def hex_to_bytes(hex_str: str) -> bytes: @@ -38,24 +75,3 @@ def get_event_loop() -> asyncio.AbstractEventLoop: event_loop = asyncio.get_event_loop() asyncio.set_event_loop(event_loop) return event_loop - - -def execute_coroutine( - coroutine: "Coroutine", event_loop: asyncio.AbstractEventLoop = None -): - """ - Helper function to run an asyncio coroutine synchronously. - - Args: - coroutine (Coroutine): The coroutine to run. - event_loop (AbstractEventLoop): The event loop to use. If `None`, attempts to fetch the already-running - loop. If one is not running, a new loop is created. - - Returns: - The result of the coroutine execution. - """ - if event_loop: - event_loop = event_loop - else: - event_loop = get_event_loop() - return event_loop.run_until_complete(asyncio.wait_for(coroutine, timeout=None)) diff --git a/tests/unit_tests/test_substrate_interface.py b/tests/unit_tests/test_substrate_interface.py index 616152d..917d220 100644 --- a/tests/unit_tests/test_substrate_interface.py +++ b/tests/unit_tests/test_substrate_interface.py @@ -4,6 +4,7 @@ from async_substrate_interface.substrate_interface import ( AsyncSubstrateInterface, + get_async_substrate_interface, ScaleObj, ) @@ -12,8 +13,7 @@ async def test_invalid_url_raises_exception(): """Test that invalid URI raises an InvalidURI exception.""" with pytest.raises(InvalidURI): - async with AsyncSubstrateInterface("non_existent_entry_point"): - pass + await get_async_substrate_interface("non_existent_entry_point") def test_scale_object():