Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add signature v2 format #983

Merged
merged 7 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bittensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
version_split = __version__.split(".")
__version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))

__new_signature_version__ = 360

# Turn off rich console locals trace.
from rich.traceback import install
Expand Down
101 changes: 65 additions & 36 deletions bittensor/_axon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,9 @@ def __new__(
if thread_pool == None:
thread_pool = futures.ThreadPoolExecutor( max_workers = config.axon.max_workers )
if server == None:
receiver_hotkey = wallet.hotkey.ss58_address
server = grpc.server( thread_pool,
interceptors=(AuthInterceptor(blacklist=blacklist),),
interceptors=(AuthInterceptor(receiver_hotkey=receiver_hotkey, blacklist=blacklist),),
maximum_concurrent_rpcs = config.axon.maximum_concurrent_rpcs,
options = [('grpc.keepalive_time_ms', 100000),
('grpc.keepalive_timeout_ms', 500000)]
Expand Down Expand Up @@ -341,22 +342,26 @@ def check_forward_callback( forward_callback:Callable, synapses:list = []):
class AuthInterceptor(grpc.ServerInterceptor):
"""Creates a new server interceptor that authenticates incoming messages from passed arguments."""

def __init__(self, key: str = "Bittensor", blacklist: List = []):
def __init__(
self,
receiver_hotkey: str,
blacklist: Callable = None,
):
r"""Creates a new server interceptor that authenticates incoming messages from passed arguments.
Args:
key (str, `optional`):
key for authentication header in the metadata (default = Bittensor)
receiver_hotkey(str):
the SS58 address of the hotkey which should be targeted by RPCs
black_list (Function, `optional`):
black list function that prevents certain pubkeys from sending messages
"""
super().__init__()
self.auth_header_value = key
self.nonces = {}
self.blacklist = blacklist
self.receiver_hotkey = receiver_hotkey

def parse_legacy_signature(
self, signature: str
) -> Union[Tuple[int, str, str, str], None]:
) -> Union[Tuple[int, str, str, str, int], None]:
r"""Attempts to parse a signature using the legacy format, using `bitxx` as a separator"""
parts = signature.split("bitxx")
if len(parts) < 4:
Expand All @@ -367,52 +372,71 @@ def parse_legacy_signature(
except ValueError:
return None
receptor_uuid, parts = parts[-1], parts[:-1]
message, parts = parts[-1], parts[:-1]
pubkey = "".join(parts)
return (nonce, pubkey, message, receptor_uuid)
signature, parts = parts[-1], parts[:-1]
sender_hotkey = "".join(parts)
return (nonce, sender_hotkey, signature, receptor_uuid, 1)

def parse_signature(self, metadata: Dict[str, str]) -> Tuple[int, str, str, str]:
def parse_signature_v2(
self, signature: str
) -> Union[Tuple[int, str, str, str, int], None]:
r"""Attempts to parse a signature using the v2 format"""
parts = signature.split(".")
if len(parts) != 4:
return None
try:
nonce = int(parts[0])
except ValueError:
return None
sender_hotkey = parts[1]
signature = parts[2]
receptor_uuid = parts[3]
return (nonce, sender_hotkey, signature, receptor_uuid, 2)

def parse_signature(
self, metadata: Dict[str, str]
) -> Tuple[int, str, str, str, int]:
r"""Attempts to parse a signature from the metadata"""
signature = metadata.get("bittensor-signature")
if signature is None:
raise Exception("Request signature missing")
parts = self.parse_legacy_signature(signature)
if parts is not None:
return parts
for parser in [self.parse_signature_v2, self.parse_legacy_signature]:
parts = parser(signature)
if parts is not None:
return parts
raise Exception("Unknown signature format")

def check_signature(
self, nonce: int, pubkey: str, signature: str, receptor_uuid: str
self,
nonce: int,
sender_hotkey: str,
signature: str,
receptor_uuid: str,
format: int,
):
r"""verification of signature in metadata. Uses the pubkey and nonce"""
keypair = Keypair(ss58_address=pubkey)
keypair = Keypair(ss58_address=sender_hotkey)
# Build the expected message which was used to build the signature.
message = f"{nonce}{pubkey}{receptor_uuid}"
if format == 2:
message = f"{nonce}.{sender_hotkey}.{self.receiver_hotkey}.{receptor_uuid}"
elif format == 1:
message = f"{nonce}{sender_hotkey}{receptor_uuid}"
else:
raise Exception("Invalid signature version")
# Build the key which uniquely identifies the endpoint that has signed
# the message.
endpoint_key = f"{pubkey}:{receptor_uuid}"
endpoint_key = f"{sender_hotkey}:{receptor_uuid}"

if endpoint_key in self.nonces.keys():
previous_nonce = self.nonces[endpoint_key]
# Nonces must be strictly monotonic over time.
if nonce - previous_nonce <= -10:
if nonce <= previous_nonce:
raise Exception("Nonce is too small")
if not keypair.verify(message, signature):
raise Exception("Signature mismatch")
self.nonces[endpoint_key] = nonce
return

if not keypair.verify(message, signature):
raise Exception("Signature mismatch")
self.nonces[endpoint_key] = nonce

def version_checking(self, metadata: Dict[str, str]):
r"""Checks the header and version in the metadata"""
provided_value = metadata.get("rpc-auth-header")
if provided_value is None or provided_value != self.auth_header_value:
raise Exception("Unexpected caller metadata")

def black_list_checking(self, pubkey: str, method: str):
def black_list_checking(self, hotkey: str, method: str):
r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey"""
if self.blacklist == None:
return
Expand All @@ -424,7 +448,7 @@ def black_list_checking(self, pubkey: str, method: str):
if request_type is None:
raise Exception("Unknown request type")

if self.blacklist(pubkey, request_type):
if self.blacklist(hotkey, request_type):
raise Exception("Request type is blacklisted")

def intercept_service(self, continuation, handler_call_details):
Expand All @@ -433,16 +457,21 @@ def intercept_service(self, continuation, handler_call_details):
metadata = dict(handler_call_details.invocation_metadata)

try:
# version checking
self.version_checking(metadata)

(nonce, pubkey, signature, receptor_uuid) = self.parse_signature(metadata)
(
nonce,
sender_hotkey,
signature,
receptor_uuid,
signature_format,
) = self.parse_signature(metadata)

# signature checking
self.check_signature(nonce, pubkey, signature, receptor_uuid)
self.check_signature(
nonce, sender_hotkey, signature, receptor_uuid, signature_format
)

# blacklist checking
self.black_list_checking(pubkey, method)
self.black_list_checking(sender_hotkey, method)

return continuation(handler_call_details)

Expand Down
28 changes: 20 additions & 8 deletions bittensor/_receptor/receptor_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,32 @@ def __del__ ( self ):
def __exit__ ( self ):
self.__del__()

def sign ( self ):
def sign_v1( self ):
r""" Uses the wallet pubkey to sign a message containing the pubkey and the time
"""
nounce = self.nounce()
message = str(nounce) + str(self.wallet.hotkey.ss58_address) + str(self.receptor_uid)
nonce = self.nonce()
message = str(nonce) + str(self.wallet.hotkey.ss58_address) + str(self.receptor_uid)
spliter = 'bitxx'
signature = spliter.join([ str(nounce), str(self.wallet.hotkey.ss58_address), "0x" + self.wallet.hotkey.sign(message).hex(), str(self.receptor_uid) ])
signature = spliter.join([ str(nonce), str(self.wallet.hotkey.ss58_address), "0x" + self.wallet.hotkey.sign(message).hex(), str(self.receptor_uid) ])
return signature

def nounce ( self ):

def sign_v2(self):
nonce = f"{self.nonce()}"
sender_hotkey = self.wallet.hotkey.ss58_address
receiver_hotkey = self.endpoint.hotkey
message = f"{nonce}.{sender_hotkey}.{receiver_hotkey}.{self.receptor_uid}"
signature = f"0x{self.wallet.hotkey.sign(message).hex()}"
return ".".join([nonce, sender_hotkey, signature, self.receptor_uid])

def sign(self):
if self.endpoint.version >= bittensor.__new_signature_version__:
return self.sign_v2()
return self.sign_v1()

def nonce ( self ):
r"""creates a string representation of the time
"""
nounce = int(clock.time() * 1000)
return nounce
return clock.monotonic_ns()

def state ( self ):
try:
Expand Down
31 changes: 18 additions & 13 deletions bittensor/_subtensor/subtensor_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,24 +671,29 @@ def serve (
# Decrypt hotkey
wallet.hotkey

with bittensor.__console__.status(":satellite: Checking Axon..."):
neuron = self.neuron_for_pubkey( wallet.hotkey.ss58_address )
if not neuron.is_null and neuron.ip == net.ip_to_int(ip) and neuron.port == port:
bittensor.__console__.print(":white_heavy_check_mark: [green]Already Served[/green]\n [bold white]ip: {}\n port: {}\n modality: {}\n hotkey: {}\n coldkey: {}[/bold white]".format(ip, port, modality, wallet.hotkey.ss58_address, wallet.coldkeypub.ss58_address))
return True

ip_as_int = net.ip_to_int(ip)
ip_version = net.ip_version(ip)

# TODO(const): subscribe with version too.
params = {
'version': bittensor.__version_as_int__,
'ip': ip_as_int,
'port': port,
'ip_type': ip_version,
'ip': net.ip_to_int(ip),
'port': port,
'ip_type': net.ip_version(ip),
'modality': modality,
'coldkey': wallet.coldkeypub.ss58_address,
}

with bittensor.__console__.status(":satellite: Checking Axon..."):
neuron = self.neuron_for_pubkey( wallet.hotkey.ss58_address )
neuron_up_to_date = not neuron.is_null and params == {
'version': neuron.version,
'ip': neuron.ip,
'port': neuron.port,
'ip_type': neuron.ip_version,
'modality': neuron.modality,
'coldkey': neuron.coldkey
}
if neuron_up_to_date:
bittensor.__console__.print(":white_heavy_check_mark: [green]Already Served[/green]\n [bold white]ip: {}\n port: {}\n modality: {}\n hotkey: {}\n coldkey: {}[/bold white]".format(ip, port, modality, wallet.hotkey.ss58_address, wallet.coldkeypub.ss58_address))
return True

if prompt:
if not Confirm.ask("Do you want to serve axon:\n [bold white]ip: {}\n port: {}\n modality: {}\n hotkey: {}\n coldkey: {}[/bold white]".format(ip, port, modality, wallet.hotkey.ss58_address, wallet.coldkeypub.ss58_address)):
return False
Expand Down
53 changes: 39 additions & 14 deletions tests/unit_tests/bittensor_tests/test_axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,37 @@
wallet = bittensor.wallet.mock()
axon = bittensor.axon(wallet = wallet)

sender_wallet = bittensor.wallet.mock()

def gen_nonce():
return f"{time.monotonic_ns()}"

def sign(wallet):
nounce = str(int(time.time() * 1000))
receptor_uid = str(uuid.uuid1())
message = "{}{}{}".format(nounce, str(wallet.hotkey.ss58_address), receptor_uid)
def sign_v1(wallet):
nonce, receptor_uid = gen_nonce(), str(uuid.uuid1())
message = "{}{}{}".format(nonce, str(wallet.hotkey.ss58_address), receptor_uid)
spliter = 'bitxx'
signature = spliter.join([ nounce, str(wallet.hotkey.ss58_address), "0x" + wallet.hotkey.sign(message).hex(), receptor_uid])
signature = spliter.join([ nonce, str(wallet.hotkey.ss58_address), "0x" + wallet.hotkey.sign(message).hex(), receptor_uid])
return signature

def test_sign():
sign(wallet)
sign(axon.wallet)
def sign_v2(sender_wallet, receiver_wallet):
nonce, receptor_uid = gen_nonce(), str(uuid.uuid1())
sender_hotkey = sender_wallet.hotkey.ss58_address
receiver_hotkey = receiver_wallet.hotkey.ss58_address
message = f"{nonce}.{sender_hotkey}.{receiver_hotkey}.{receptor_uid}"
signature = f"0x{sender_wallet.hotkey.sign(message).hex()}"
return ".".join([nonce, sender_hotkey, signature, receptor_uid])

def sign(sender_wallet, receiver_wallet, receiver_version):
if receiver_version >= bittensor.__new_signature_version__:
return sign_v2(sender_wallet, receiver_wallet)
return sign_v1(sender_wallet)

def test_sign_v1():
sign_v1(wallet)
sign_v1(axon.wallet)

def test_sign_v2():
sign_v2(sender_wallet, wallet)

def test_forward_wandb():
inputs_raw = torch.rand(3, 3, bittensor.__network_dim__)
Expand Down Expand Up @@ -902,7 +920,7 @@ def forward( inputs_x: torch.FloatTensor, synapses, model_output = None):
assert code == bittensor.proto.ReturnCode.Success


def test_grpc_forward_works():
def run_test_grpc_forward_works(receiver_version):
def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):
return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__])
axon = bittensor.axon (
Expand All @@ -927,14 +945,14 @@ def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):

request = bittensor.proto.TensorMessage(
version = bittensor.__version_as_int__,
hotkey = axon.wallet.hotkey.ss58_address,
hotkey = sender_wallet.hotkey.ss58_address,
tensors = [inputs_serialized],
synapses = [ syn.serialize_to_wire_proto() for syn in synapses ]
)
response = stub.Forward(request,
metadata = (
('rpc-auth-header','Bittensor'),
('bittensor-signature',sign(axon.wallet)),
('bittensor-signature',sign(sender_wallet, wallet, receiver_version)),
('bittensor-version',str(bittensor.__version_as_int__)),
))

Expand All @@ -943,8 +961,11 @@ def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):
assert response.return_code == bittensor.proto.ReturnCode.Success
axon.stop()

def test_grpc_forward_works():
for receiver_version in [341, bittensor.__new_signature_version__, bittensor.__version_as_int__]:
run_test_grpc_forward_works(receiver_version)

def test_grpc_backward_works():
def run_test_grpc_backward_works(receiver_version):
def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):
return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__], requires_grad=True)

Expand All @@ -969,19 +990,23 @@ def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):
grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw)
request = bittensor.proto.TensorMessage(
version = bittensor.__version_as_int__,
hotkey = '1092310312914',
hotkey = sender_wallet.hotkey.ss58_address,
tensors = [inputs_serialized, grads_serialized],
synapses = [ syn.serialize_to_wire_proto() for syn in synapses ]
)
response = stub.Backward(request,
metadata = (
('rpc-auth-header','Bittensor'),
('bittensor-signature',sign(axon.wallet)),
('bittensor-signature',sign(sender_wallet, wallet, receiver_version)),
('bittensor-version',str(bittensor.__version_as_int__)),
))
assert response.return_code == bittensor.proto.ReturnCode.Success
axon.stop()

def test_grpc_backward_works():
for receiver_version in [341, bittensor.__new_signature_version__, bittensor.__version_as_int__]:
run_test_grpc_backward_works(receiver_version)

def test_grpc_forward_fails():
def forward( inputs_x:torch.FloatTensor, synapse, model_output = None):
return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__])
Expand Down
Loading