Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
egbertbouman committed May 21, 2018
1 parent 2520923 commit 5e87c44
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 16 deletions.
42 changes: 42 additions & 0 deletions Tribler/Test/Community/DHT/test_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from twisted.internet.defer import inlineCallbacks, succeed

from Tribler.community.dht.provider import DHTCommunityProvider
from Tribler.pyipv8.ipv8.util import blocking_call_on_reactor_thread
from Tribler.Test.Core.base_test import TriblerCoreTest, MockObject


class TestDHTProvider(TriblerCoreTest):

@blocking_call_on_reactor_thread
@inlineCallbacks
def setUp(self, annotate=True):
yield super(TestDHTProvider, self).setUp(annotate=annotate)

def mocked_find_values(key):
return succeed(['\x01\x01\x01\x01\x04\xd2'])

def mocked_store(key, value):
self.stored_value = value
return succeed([])

self.cb_invoked = False
self.stored_value = None
self.dhtcommunity = MockObject()
self.dhtcommunity.find_values = mocked_find_values
self.dhtcommunity.store = mocked_store
self.dhtcommunity.my_estimated_lan = '1.1.1.1'
self.dht_provider = DHTCommunityProvider(self.dhtcommunity, 1234)

def test_lookup(self):
def check_result(result):
self.cb_invoked = True
self.assertEqual(result, [('1.1.1.1', 1234)])
self.dht_provider.lookup('a' * 20, check_result)
self.assertTrue(self.cb_invoked)

def test_announce(self):
def check_result(result):
self.cb_invoked = True
self.dht_provider.announce('a' * 20, check_result)
self.assertTrue(self.cb_invoked)
self.assertEqual(self.stored_value, '\x01\x01\x01\x01\x04\xd2')
2 changes: 1 addition & 1 deletion Tribler/community/dht/community.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _send_find_request(self, node, target, force_nodes):
def find(self, target, force_nodes=False):
nodes_closest = set(self.routing_table.closest_nodes(target, max_nodes=MAX_FIND_WALKS))
if not nodes_closest:
returnValue(fail(Failure(RuntimeError("No nodes found in the routing table"))))
returnValue(Failure(RuntimeError("No nodes found in the routing table")))

nodes_tried = set()
values = set()
Expand Down
19 changes: 12 additions & 7 deletions Tribler/community/dht/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,11 @@ def to_pack_list(self):
values_str = ''.join([pack('!H', len(value)) + value for value in self.values])
data.append(('varlenH', values_str))

nodes_str = ''.join([inet_aton(node.address[0]) +
pack("!H", node.address[1]) +
node.public_key.key_to_bin() for node in self.nodes])
nodes_str = ''
for node in self.nodes:
key = node.public_key.key_to_bin()
nodes_str += inet_aton(node.address[0]) + pack("!H", node.address[1])
nodes_str += pack('!H', len(key)) + key
data.append(('varlenH', nodes_str))

return data
Expand All @@ -111,10 +113,13 @@ def from_unpack_list(cls, identifier, values_str, nodes_str):
index += length

nodes = []
for i in xrange(0, len(nodes_str), 80):
key = nodes_str[i + 6:i + 80]
address = (inet_ntoa(nodes_str[i:i + 4]),
unpack('!H', nodes_str[i + 4:i + 6])[0])
index = 0
while index < len(nodes_str):
ip, port, key_length = unpack('!4sHH', nodes_str[index:index + 8])
index += 8
address = (inet_ntoa(ip), port)
key = nodes_str[index:index + key_length]
index += key_length
nodes.append(Node(key, address=address))

return FindResponsePayload(identifier, values, nodes)
23 changes: 15 additions & 8 deletions Tribler/community/dht/provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from Tribler.pyipv8.ipv8.messaging.deprecated.encoding import encode, decode
import struct
import socket


class DHTCommunityProvider(object):
Expand All @@ -15,12 +15,19 @@ def __init__(self, dhtcommunity, bt_port):

def lookup(self, info_hash, cb):
def callback(values):
values = [decode(v)[1] for v in values]
cb(info_hash, values, None)
self.dhtcommunity.find_values(info_hash).addCallbacks(callback, lambda _: None)
addresses = []
for v in values:
try:
ip, port = struct.unpack('!4sH', v)
address = (socket.inet_ntoa(ip), port)
addresses.append(address)
except struct.error, socket.error:
self.logger.info("Failed to decode value '%s' from DHTCommunity", v)
return addresses
self.dhtcommunity.find_values(info_hash).addCallback(callback).addCallback(cb)

def announce(self, info_hash):
def announce(self, info_hash, cb):
def callback(node):
self.logger.info("Announced %s to the DHTCommunity", info_hash.encode('hex'))
value = encode((self.dhtcommunity.my_estimated_lan, self.bt_port))
self.dhtcommunity.store(info_hash, value).addCallbacks(callback, lambda _: None)
value = socket.inet_aton(self.dhtcommunity.my_estimated_lan) + struct.pack("!H", self.bt_port)
self.dhtcommunity.store(info_hash, value).addCallback(callback).addCallback(cb)

0 comments on commit 5e87c44

Please sign in to comment.