diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py
index 612670c..502b743 100644
--- a/async_substrate_interface/async_substrate.py
+++ b/async_substrate_interface/async_substrate.py
@@ -49,6 +49,7 @@
     Preprocessed,
 )
 from async_substrate_interface.utils import hex_to_bytes, json, get_next_id
+from async_substrate_interface.utils.cache import async_sql_lru_cache
 from async_substrate_interface.utils.decoding import (
     _determine_if_old_runtime_call,
     _bt_decode_to_dict_or_list,
@@ -1659,8 +1660,11 @@ def convert_event_data(data):
             events.append(convert_event_data(item))
         return events
 
-    @a.lru_cache(maxsize=512)  # large cache with small items
+    @a.lru_cache(maxsize=512)
     async def get_parent_block_hash(self, block_hash):
+        return await self._get_parent_block_hash(block_hash)
+
+    async def _get_parent_block_hash(self, block_hash):
         block_header = await self.rpc_request("chain_getHeader", [block_hash])
 
         if block_header["result"] is None:
@@ -1672,16 +1676,22 @@ async def get_parent_block_hash(self, block_hash):
             return block_hash
         return parent_block_hash
 
-    @a.lru_cache(maxsize=16)  # small cache with large items
+    @a.lru_cache(maxsize=16)
     async def get_block_runtime_info(self, block_hash: str) -> dict:
+        return await self._get_block_runtime_info(block_hash)
+
+    async def _get_block_runtime_info(self, block_hash: str) -> dict:
         """
         Retrieve the runtime info of given block_hash
         """
         response = await self.rpc_request("state_getRuntimeVersion", [block_hash])
         return response.get("result")
 
-    @a.lru_cache(maxsize=512)  # large cache with small items
+    @a.lru_cache(maxsize=512)
     async def get_block_runtime_version_for(self, block_hash: str):
+        return await self._get_block_runtime_version_for(block_hash)
+
+    async def _get_block_runtime_version_for(self, block_hash: str):
         """
         Retrieve the runtime version of the parent of a given block_hash
         """
@@ -1914,7 +1924,6 @@ async def _make_rpc_request(
 
         return request_manager.get_results()
 
-    @a.lru_cache(maxsize=512)  # RPC methods are unlikely to change often
     async def supports_rpc_method(self, name: str) -> bool:
         """
         Check if substrate RPC supports given method
@@ -1985,8 +1994,11 @@ async def rpc_request(
         else:
             raise SubstrateRequestException(result[payload_id][0])
 
-    @a.lru_cache(maxsize=512)  # block_id->block_hash does not change
+    @a.lru_cache(maxsize=512)
     async def get_block_hash(self, block_id: int) -> str:
+        return await self._get_block_hash(block_id)
+
+    async def _get_block_hash(self, block_id: int) -> str:
        return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"]
 
     async def get_chain_head(self) -> str:
@@ -3230,6 +3242,28 @@ async def _handler(block_data: dict[str, Any]):
 
         return await co
 
+
+class DiskCachedAsyncSubstrateInterface(AsyncSubstrateInterface):
+    """
+    Experimental new class that uses disk-caching in addition to memory-caching for the cached methods
+    """
+
+    @async_sql_lru_cache(maxsize=512)
+    async def get_parent_block_hash(self, block_hash):
+        return await self._get_parent_block_hash(block_hash)
+
+    @async_sql_lru_cache(maxsize=16)
+    async def get_block_runtime_info(self, block_hash: str) -> dict:
+        return await self._get_block_runtime_info(block_hash)
+
+    @async_sql_lru_cache(maxsize=512)
+    async def get_block_runtime_version_for(self, block_hash: str):
+        return await self._get_block_runtime_version_for(block_hash)
+
+    @async_sql_lru_cache(maxsize=512)
+    async def get_block_hash(self, block_id: int) -> str:
+        return await self._get_block_hash(block_id)
+
 async def get_async_substrate_interface(
     url: str,
     use_remote_preset: bool = False,
diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py
index d327687..daad7ce 100644
--- a/async_substrate_interface/sync_substrate.py
+++ b/async_substrate_interface/sync_substrate.py
@@ -1,6 +1,6 @@
+import functools
 import logging
 import random
-from functools import lru_cache
 from hashlib import blake2b
 from typing import Optional, Union, Callable, Any
 
@@ -1406,7 +1406,7 @@ def convert_event_data(data):
             events.append(convert_event_data(item))
         return events
 
-    @lru_cache(maxsize=512)  # large cache with small items
+    @functools.lru_cache(maxsize=512)
     def get_parent_block_hash(self, block_hash):
         block_header = self.rpc_request("chain_getHeader", [block_hash])
 
@@ -1419,7 +1419,7 @@ def get_parent_block_hash(self, block_hash):
             return block_hash
         return parent_block_hash
 
-    @lru_cache(maxsize=16)  # small cache with large items
+    @functools.lru_cache(maxsize=16)
     def get_block_runtime_info(self, block_hash: str) -> dict:
         """
         Retrieve the runtime info of given block_hash
@@ -1427,7 +1427,7 @@ def get_block_runtime_info(self, block_hash: str) -> dict:
         response = self.rpc_request("state_getRuntimeVersion", [block_hash])
         return response.get("result")
 
-    @lru_cache(maxsize=512)  # large cache with small items
+    @functools.lru_cache(maxsize=512)
     def get_block_runtime_version_for(self, block_hash: str):
         """
         Retrieve the runtime version of the parent of a given block_hash
@@ -1655,8 +1655,7 @@ def _make_rpc_request(
 
         return request_manager.get_results()
 
-    # TODO change this logic
-    @lru_cache(maxsize=512)  # RPC methods are unlikely to change often
+    @functools.lru_cache(maxsize=512)
     def supports_rpc_method(self, name: str) -> bool:
         """
         Check if substrate RPC supports given method
@@ -1727,7 +1726,7 @@ def rpc_request(
         else:
             raise SubstrateRequestException(result[payload_id][0])
 
-    @lru_cache(maxsize=512)  # block_id->block_hash does not change
+    @functools.lru_cache(maxsize=512)
     def get_block_hash(self, block_id: int) -> str:
         return self.rpc_request("chain_getBlockHash", [block_id])["result"]
 
diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py
new file mode 100644
index 0000000..ab4f457
--- /dev/null
+++ b/async_substrate_interface/utils/cache.py
@@ -0,0 +1,134 @@
+import functools
+import os
+import pickle
+import sqlite3
+import asyncstdlib as a
+
+USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
+CACHE_LOCATION = (
+    os.path.expanduser(
+        os.getenv("CACHE_LOCATION", "~/.cache/async-substrate-interface")
+    )
+    if USE_CACHE
+    else ":memory:"
+)
+
+
+def _get_table_name(func):
+    """Convert "ClassName.method_name" to "ClassName_method_name"."""
+    return func.__qualname__.replace(".", "_")
+
+
+def _check_if_local(chain: str) -> bool:
+    return any([x in chain for x in ["127.0.0.1", "localhost", "0.0.0.0"]])
+
+
+def _create_table(c, conn, table_name):
+    c.execute(
+        f"""CREATE TABLE IF NOT EXISTS {table_name}
+        (
+            rowid INTEGER PRIMARY KEY AUTOINCREMENT,
+            key BLOB,
+            value BLOB,
+            chain TEXT,
+            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+        );
+        """
+    )
+    c.execute(
+        f"""CREATE TRIGGER IF NOT EXISTS prune_rows_trigger_{table_name} AFTER INSERT ON {table_name}
+        BEGIN
+            DELETE FROM {table_name}
+            WHERE rowid IN (
+                SELECT rowid FROM {table_name}
+                ORDER BY created_at DESC
+                LIMIT -1 OFFSET 500
+            );
+        END;"""
+    )
+    conn.commit()
+
+
+def _retrieve_from_cache(c, table_name, key, chain):
+    try:
+        c.execute(
+            f"SELECT value FROM {table_name} WHERE key=? AND chain=?", (key, chain)
+        )
+        result = c.fetchone()
+        if result is not None:
+            return pickle.loads(result[0])
+    except (pickle.PickleError, sqlite3.Error) as e:
+        print(f"Cache error: {str(e)}")
+        pass
+
+
+def _insert_into_cache(c, conn, table_name, key, result, chain):
+    try:
+        c.execute(
+            f"INSERT OR REPLACE INTO {table_name} (key, value, chain) VALUES (?,?,?)",
+            (key, pickle.dumps(result), chain),
+        )
+        conn.commit()
+    except (pickle.PickleError, sqlite3.Error) as e:
+        print(f"Cache error: {str(e)}")
+        pass
+
+
+def sql_lru_cache(maxsize=None):
+    def decorator(func):
+        conn = sqlite3.connect(CACHE_LOCATION)
+        c = conn.cursor()
+        table_name = _get_table_name(func)
+        _create_table(c, conn, table_name)
+
+        @functools.lru_cache(maxsize=maxsize)
+        def inner(self, *args, **kwargs):
+            c = conn.cursor()
+            key = pickle.dumps((args, kwargs))
+            chain = self.url
+            if not (local_chain := _check_if_local(chain)) or not USE_CACHE:
+                result = _retrieve_from_cache(c, table_name, key, chain)
+                if result is not None:
+                    return result
+
+            # If not in DB, call func and store in DB
+            result = func(self, *args, **kwargs)
+
+            if not local_chain or not USE_CACHE:
+                _insert_into_cache(c, conn, table_name, key, result, chain)
+
+            return result
+
+        return inner
+
+    return decorator
+
+
+def async_sql_lru_cache(maxsize=None):
+    def decorator(func):
+        conn = sqlite3.connect(CACHE_LOCATION)
+        c = conn.cursor()
+        table_name = _get_table_name(func)
+        _create_table(c, conn, table_name)
+
+        @a.lru_cache(maxsize=maxsize)
+        async def inner(self, *args, **kwargs):
+            c = conn.cursor()
+            key = pickle.dumps((args, kwargs))
+            chain = self.url
+
+            if not (local_chain := _check_if_local(chain)) or not USE_CACHE:
+                result = _retrieve_from_cache(c, table_name, key, chain)
+                if result is not None:
+                    return result
+
+            # If not in DB, call func and store in DB
+            result = await func(self, *args, **kwargs)
+            if not local_chain or not USE_CACHE:
+                _insert_into_cache(c, conn, table_name, key, result, chain)
+
+            return result
+
+        return inner
+
+    return decorator
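
A minimal usage sketch of the new disk-backed class, not part of the patch: the endpoint URL is a placeholder, and `async with` support is assumed from the existing `AsyncSubstrateInterface` API. The table name in the comment follows from `_get_table_name` in `cache.py`.

```python
import asyncio

from async_substrate_interface.async_substrate import DiskCachedAsyncSubstrateInterface


async def main():
    # Hypothetical endpoint; any non-local URL makes the decorators persist
    # results to the SQLite database at ~/.cache/async-substrate-interface.
    async with DiskCachedAsyncSubstrateInterface("wss://entrypoint.example:443") as substrate:
        # First call: RPC round-trip, then stored in both the in-memory LRU
        # (asyncstdlib) and the "DiskCachedAsyncSubstrateInterface_get_block_hash" table.
        first = await substrate.get_block_hash(100)
        # Repeat call: served from cache, and the disk copy survives process
        # restarts, which is safe because block_id -> block_hash never changes.
        assert first == await substrate.get_block_hash(100)


asyncio.run(main())
```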
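The decorators also impose a small contract on the host object: the wrapped method's owner must expose a `url` attribute (stored in the `chain` column so results from different networks do not collide), the arguments must be picklable (they form the cache key via `pickle.dumps((args, kwargs))`), the disk layer is skipped for local chains (`127.0.0.1`, `localhost`, `0.0.0.0`), and the insert trigger prunes each table to its newest 500 rows. A sketch of that contract with a hypothetical `ChainClient` class, assuming the module path added by this patch:

```python
from async_substrate_interface.utils.cache import sql_lru_cache


class ChainClient:
    """Hypothetical stand-in for the substrate interface classes."""

    def __init__(self, url: str):
        self.url = url  # read by the decorator and stored in the `chain` column

    @sql_lru_cache(maxsize=128)
    def expensive_lookup(self, block_id: int) -> str:
        print("cache miss, doing work...")  # only printed on a miss
        return f"result-for-{block_id}"


client = ChainClient("wss://rpc.example:443")  # non-local, so the disk layer applies
client.expensive_lookup(7)  # miss: computed, written to memory and SQLite
client.expensive_lookup(7)  # hit: served from the in-memory functools LRU
```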