Skip to content

Commit 001b46c

Browse files
authored
Merge pull request #2611 from opentensor/feat/roman/update-metagraph-class
Update metagraph class with `rao` stuff
2 parents 85eaa90 + 763c8af commit 001b46c

File tree

2 files changed

+119
-9
lines changed

2 files changed

+119
-9
lines changed

bittensor/core/chain_data/subnet_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from dataclasses import dataclass
7-
from typing import Optional
7+
from typing import Optional, Union
88

99
from scalecodec.utils.ss58 import ss58_encode
1010

@@ -39,7 +39,7 @@ class SubnetState:
3939
emission_history: list[list[int]]
4040

4141
@classmethod
42-
def from_vec_u8(cls, vec_u8: list[int]) -> Optional["SubnetState"]:
42+
def from_vec_u8(cls, vec_u8: Union[list[int], bytes]) -> Optional["SubnetState"]:
4343
if len(vec_u8) == 0:
4444
return None
4545
decoded = from_scale_encoding(vec_u8, ChainDataType.SubnetState)

bittensor/core/metagraph.py

Lines changed: 117 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,25 @@
99
from typing import Optional, Union
1010

1111
import numpy as np
12+
from async_substrate_interface.errors import SubstrateRequestException
1213
from numpy.typing import NDArray
1314

15+
from bittensor.core import settings
16+
from bittensor.core.chain_data import AxonInfo, SubnetState
1417
from bittensor.utils.btlogging import logging
1518
from bittensor.utils.registration import torch, use_torch
1619
from bittensor.utils.weight_utils import (
1720
convert_weight_uids_and_vals_to_tensor,
1821
convert_bond_uids_and_vals_to_tensor,
1922
convert_root_weight_uids_and_vals_to_tensor,
2023
)
21-
from bittensor.core import settings
22-
from bittensor.core.chain_data import AxonInfo
2324

2425
# For annotation purposes
2526
if typing.TYPE_CHECKING:
2627
from bittensor.core.subtensor import Subtensor
2728
from bittensor.core.async_subtensor import AsyncSubtensor
2829
from bittensor.core.chain_data import NeuronInfo, NeuronInfoLite
30+
from bittensor.utils.balance import Balance
2931

3032

3133
Tensor = Union["torch.nn.Parameter", NDArray]
@@ -219,8 +221,6 @@ class MetagraphMixin(ABC):
219221
n: Tensor
220222
neurons: list[Union["NeuronInfo", "NeuronInfoLite"]]
221223
block: Tensor
222-
stake: Tensor
223-
total_stake: Tensor
224224
ranks: Tensor
225225
trust: Tensor
226226
consensus: Tensor
@@ -234,11 +234,34 @@ class MetagraphMixin(ABC):
234234
weights: Tensor
235235
bonds: Tensor
236236
uids: Tensor
237+
alpha_stake: Tensor
238+
tao_stake: Tensor
239+
stake: Tensor
237240
axons: list[AxonInfo]
238241
chain_endpoint: Optional[str]
239242
subtensor: Optional["AsyncSubtensor"]
240243
_dtype_registry = {"int64": np.int64, "float32": np.float32, "bool": bool}
241244

245+
@property
246+
def TS(self) -> list["Balance"]:
247+
"""
248+
Represents the tao stake of each neuron in the Bittensor network.
249+
250+
Returns:
251+
list["Balance"]: The list of tao stake of each neuron in the network.
252+
"""
253+
return self.tao_stake
254+
255+
@property
256+
def AS(self) -> list["Balance"]:
257+
"""
258+
Represents the alpha stake of each neuron in the Bittensor network.
259+
260+
Returns:
261+
list["Balance"]: The list of alpha stake of each neuron in the network.
262+
"""
263+
return self.alpha_stake
264+
242265
@property
243266
def S(self) -> Union[NDArray, "torch.nn.Parameter"]:
244267
"""
@@ -251,7 +274,7 @@ def S(self) -> Union[NDArray, "torch.nn.Parameter"]:
251274
NDArray: A tensor representing the stake of each neuron in the network. Higher values signify a greater
252275
stake held by the respective neuron.
253276
"""
254-
return self.total_stake
277+
return self.stake
255278

256279
@property
257280
def R(self) -> Union[NDArray, "torch.nn.Parameter"]:
@@ -554,8 +577,6 @@ def state_dict(self):
554577
"version": self.version,
555578
"n": self.n,
556579
"block": self.block,
557-
"stake": self.stake,
558-
"total_stake": self.total_stake,
559580
"ranks": self.ranks,
560581
"trust": self.trust,
561582
"consensus": self.consensus,
@@ -571,6 +592,9 @@ def state_dict(self):
571592
"uids": self.uids,
572593
"axons": self.axons,
573594
"neurons": self.neurons,
595+
"alpha_stake": self.alpha_stake,
596+
"tao_stake": self.tao_stake,
597+
"stake": self.stake,
574598
}
575599

576600
@staticmethod
@@ -1284,6 +1308,9 @@ async def sync(
12841308
if not lite:
12851309
await self._set_weights_and_bonds(subtensor=subtensor)
12861310

1311+
# Fills in the stake associated attributes of a class instance from a chain response.
1312+
await self._get_all_stakes_from_chain(subtensor=subtensor)
1313+
12871314
async def _initialize_subtensor(
12881315
self, subtensor: "AsyncSubtensor"
12891316
) -> "AsyncSubtensor":
@@ -1448,6 +1475,46 @@ async def _process_root_weights(
14481475
)
14491476
return tensor_param
14501477

1478+
async def _get_all_stakes_from_chain(
1479+
self, subtensor: Optional["AsyncSubtensor"] = None
1480+
):
1481+
"""Fills in the stake associated attributes of a class instance from a chain response."""
1482+
try:
1483+
if not subtensor:
1484+
subtensor = self._initialize_subtensor(subtensor=subtensor)
1485+
1486+
hex_bytes_result = await subtensor.query_runtime_api(
1487+
runtime_api="SubnetInfoRuntimeApi",
1488+
method="get_subnet_state",
1489+
params=[self.netuid],
1490+
)
1491+
1492+
if hex_bytes_result is None:
1493+
logging.debug(
1494+
f"Unable to retrieve subnet state for netuid `{self.netuid}`."
1495+
)
1496+
return []
1497+
1498+
if hex_bytes_result.startswith("0x"):
1499+
bytes_result = bytes.fromhex(hex_bytes_result[2:])
1500+
else:
1501+
bytes_result = bytes.fromhex(hex_bytes_result)
1502+
1503+
subnet_state: "SubnetState" = SubnetState.from_vec_u8(bytes_result)
1504+
if self.netuid == 0:
1505+
self.total_stake = self.stake = self.tao_stake = self.alpha_stake = (
1506+
subnet_state.tao_stake
1507+
)
1508+
return subnet_state
1509+
1510+
self.alpha_stake = subnet_state.alpha_stake
1511+
self.tao_stake = [b * 0.018 for b in subnet_state.tao_stake]
1512+
self.total_stake = self.stake = subnet_state.total_stake
1513+
return subnet_state
1514+
1515+
except (SubstrateRequestException, AttributeError) as e:
1516+
logging.debug(e)
1517+
14511518

14521519
class Metagraph(NumpyOrTorch):
14531520
def __init__(
@@ -1512,6 +1579,8 @@ def sync(
15121579
15131580
metagraph.sync(block=history_block, lite=False, subtensor=subtensor)
15141581
"""
1582+
1583+
# Initialize subtensor
15151584
subtensor = self._initialize_subtensor(subtensor)
15161585

15171586
if (
@@ -1538,6 +1607,9 @@ def sync(
15381607
if not lite:
15391608
self._set_weights_and_bonds(subtensor=subtensor)
15401609

1610+
# Fills in the stake associated attributes of a class instance from a chain response.
1611+
self._get_all_stakes_from_chain(subtensor=subtensor)
1612+
15411613
def _initialize_subtensor(self, subtensor: "Subtensor") -> "Subtensor":
15421614
"""
15431615
Initializes the subtensor to be used for syncing the metagraph.
@@ -1694,6 +1766,44 @@ def _process_root_weights(
16941766
)
16951767
return tensor_param
16961768

1769+
def _get_all_stakes_from_chain(self, subtensor: Optional["Subtensor"] = None):
1770+
"""Fills in the stake associated attributes of a class instance from a chain response."""
1771+
try:
1772+
if not subtensor:
1773+
subtensor = self._initialize_subtensor()
1774+
1775+
hex_bytes_result = subtensor.query_runtime_api(
1776+
runtime_api="SubnetInfoRuntimeApi",
1777+
method="get_subnet_state",
1778+
params=[self.netuid],
1779+
)
1780+
1781+
if hex_bytes_result is None:
1782+
logging.debug(
1783+
f"Unable to retrieve subnet state for netuid `{self.netuid}`."
1784+
)
1785+
return []
1786+
1787+
if hex_bytes_result.startswith("0x"):
1788+
bytes_result = bytes.fromhex(hex_bytes_result[2:])
1789+
else:
1790+
bytes_result = bytes.fromhex(hex_bytes_result)
1791+
1792+
subnet_state: "SubnetState" = SubnetState.from_vec_u8(bytes_result)
1793+
if self.netuid == 0:
1794+
self.total_stake = self.stake = self.tao_stake = self.alpha_stake = (
1795+
subnet_state.tao_stake
1796+
)
1797+
return subnet_state
1798+
1799+
self.alpha_stake = subnet_state.alpha_stake
1800+
self.tao_stake = [b * 0.018 for b in subnet_state.tao_stake]
1801+
self.total_stake = self.stake = subnet_state.total_stake
1802+
return subnet_state
1803+
1804+
except (SubstrateRequestException, AttributeError) as e:
1805+
logging.debug(e)
1806+
16971807

16981808
async def async_metagraph(
16991809
netuid: int,

0 commit comments

Comments
 (0)