Skip to content

Commit

Permalink
Move get_network_info endpoint to the base rpc server class (#17662)
Browse files Browse the repository at this point in the history
* Add get_network_info endpoint to the base rpc_server class, remove dupe implementations from wallet/full_node
Enables also getting network info from other RPC servers (crawler, etc)

* Add get_network_info test to full_node

* Remove unused var

* Update tests/core/test_full_node_rpc.py

Co-authored-by: Kyle Altendorf <sda@fstab.net>

* Use context manager for client

* Add success: true to the assertion to be complete

* Remove the **

---------

Co-authored-by: Kyle Altendorf <sda@fstab.net>
  • Loading branch information
cmmarslender and altendky authored Mar 7, 2024
1 parent 57193ae commit 26d5a25
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 12 deletions.
6 changes: 0 additions & 6 deletions chia/rpc/full_node_rpc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def get_routes(self) -> Dict[str, Endpoint]:
# this function is just here for backwards-compatibility. It will probably
# be removed in the future
"/get_initial_freeze_period": self.get_initial_freeze_period,
"/get_network_info": self.get_network_info,
"/get_recent_signage_point_or_eos": self.get_recent_signage_point_or_eos,
# Coins
"/get_coin_records_by_puzzle_hash": self.get_coin_records_by_puzzle_hash,
Expand Down Expand Up @@ -285,11 +284,6 @@ async def get_blockchain_state(self, _: Dict[str, Any]) -> EndpointResult:
self.cached_blockchain_state = dict(response["blockchain_state"])
return response

async def get_network_info(self, _: Dict[str, Any]) -> EndpointResult:
network_name = self.service.config["selected_network"]
address_prefix = self.service.config["network_overrides"]["config"][network_name]["address_prefix"]
return {"network_name": network_name, "network_prefix": address_prefix}

async def get_recent_signage_point_or_eos(self, request: Dict[str, Any]) -> EndpointResult:
if "sp_hash" not in request:
challenge_hash: bytes32 = bytes32.from_hexstr(request["challenge_hash"])
Expand Down
8 changes: 8 additions & 0 deletions chia/rpc/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class RpcServer(Generic[_T_RpcApiProtocol]):
service_name: str
ssl_context: SSLContext
ssl_client_context: SSLContext
net_config: Dict[str, Any]
webserver: Optional[WebServer] = None
daemon_heartbeat: int = 300
daemon_connection_task: Optional[asyncio.Task[None]] = None
Expand Down Expand Up @@ -163,6 +164,7 @@ def create(
service_name,
ssl_context,
ssl_client_context,
net_config,
daemon_heartbeat=daemon_heartbeat,
prefer_ipv6=prefer_ipv6,
)
Expand Down Expand Up @@ -235,6 +237,7 @@ def listen_port(self) -> uint16:
def _get_routes(self) -> Dict[str, Endpoint]:
return {
**self.rpc_api.get_routes(),
"/get_network_info": self.get_network_info,
"/get_connections": self.get_connections,
"/open_connection": self.open_connection,
"/close_connection": self.close_connection,
Expand All @@ -249,6 +252,11 @@ async def get_routes(self, request: Dict[str, Any]) -> EndpointResult:
"routes": list(self._get_routes().keys()),
}

async def get_network_info(self, _: Dict[str, Any]) -> EndpointResult:
network_name = self.net_config["selected_network"]
address_prefix = self.net_config["network_overrides"]["config"][network_name]["address_prefix"]
return {"network_name": network_name, "network_prefix": address_prefix}

async def get_connections(self, request: Dict[str, Any]) -> EndpointResult:
request_node_type: Optional[NodeType] = None
if "node_type" in request:
Expand Down
6 changes: 0 additions & 6 deletions chia/rpc/wallet_rpc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def get_routes(self) -> Dict[str, Endpoint]:
# this function is just here for backwards-compatibility. It will probably
# be removed in the future
"/get_initial_freeze_period": self.get_initial_freeze_period,
"/get_network_info": self.get_network_info,
# Wallet management
"/get_wallets": self.get_wallets,
"/create_new_wallet": self.create_new_wallet,
Expand Down Expand Up @@ -594,11 +593,6 @@ async def get_height_info(self, request: Dict[str, Any]) -> EndpointResult:
height = await self.service.wallet_state_manager.blockchain.get_finished_sync_up_to()
return {"height": height}

async def get_network_info(self, request: Dict[str, Any]) -> EndpointResult:
network_name = self.service.config["selected_network"]
address_prefix = self.service.config["network_overrides"]["config"][network_name]["address_prefix"]
return {"network_name": network_name, "network_prefix": address_prefix}

async def push_tx(self, request: Dict[str, Any]) -> EndpointResult:
nodes = self.service.server.get_connections(NodeType.FULL_NODE)
if len(nodes) == 0:
Expand Down
16 changes: 16 additions & 0 deletions tests/core/test_full_node_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,22 @@ async def test_signage_points(two_nodes_sim_and_wallets_services, empty_blockcha
await client.await_closed()


@pytest.mark.anyio
async def test_get_network_info(one_wallet_and_one_simulator_services, self_hostname):
nodes, _, bt = one_wallet_and_one_simulator_services
(full_node_service_1,) = nodes

async with FullNodeRpcClient.create_as_context(
self_hostname,
full_node_service_1.rpc_server.listen_port,
full_node_service_1.root_path,
full_node_service_1.config,
) as client:
await validate_get_routes(client, full_node_service_1.rpc_server.rpc_api)
network_info = await client.fetch("get_network_info", {})
assert network_info == {"network_name": "testnet0", "network_prefix": "txch", "success": True}


@pytest.mark.anyio
async def test_get_blockchain_state(one_wallet_and_one_simulator_services, self_hostname):
num_blocks = 5
Expand Down
1 change: 1 addition & 0 deletions tests/util/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ async def validate_get_routes(client: RpcClient, api: RpcApiProtocol) -> None:
routes_api = list(api.get_routes().keys())
# TODO: avoid duplication of RpcServer.get_routes()
routes_server = [
"/get_network_info",
"/get_connections",
"/open_connection",
"/close_connection",
Expand Down

0 comments on commit 26d5a25

Please sign in to comment.