Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #6360: Optional dependency on REST manager in components #6381

Merged
merged 7 commits into from
Sep 29, 2021
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from ipv8.peerdiscovery.discovery import RandomWalk

from ipv8_service import IPv8

from tribler_common.simpledefs import STATEDIR_DB_DIR

from tribler_core.components.base import Component
from tribler_core.components.ipv8 import Ipv8Component
from tribler_core.components.reporter import ReporterComponent
from tribler_core.components.restapi import RESTComponent
from tribler_core.components.restapi import RestfulComponent
from tribler_core.components.upgrade import UpgradeComponent
from tribler_core.components.bandwidth_accounting.community.community import (
BandwidthAccountingCommunity,
Expand All @@ -14,24 +17,20 @@
from tribler_core.restapi.rest_manager import RESTManager


class BandwidthAccountingComponent(Component):
class BandwidthAccountingComponent(RestfulComponent):
community: BandwidthAccountingCommunity

_rest_manager: RESTManager
_ipv8: IPv8

async def run(self):
kozlovsky marked this conversation as resolved.
Show resolved Hide resolved
await self.get_component(ReporterComponent)
await super().run()
await self.get_component(UpgradeComponent)
config = self.session.config

ipv8_component = await self.require_component(Ipv8Component)
self._ipv8 = ipv8_component.ipv8
peer = ipv8_component.peer

rest_component = await self.require_component(RESTComponent)
self._rest_manager = rest_component.rest_manager

if config.general.testnet or config.bandwidth_accounting.testnet:
bandwidth_cls = BandwidthAccountingTestnetCommunity
else:
Expand All @@ -48,10 +47,9 @@ async def run(self):
community.bootstrappers.append(ipv8_component.make_bootstrapper())

self.community = community
self._rest_manager.get_endpoint('trustview').bandwidth_db = community.database
self._rest_manager.get_endpoint('bandwidth').bandwidth_community = community
await self.init_endpoints(endpoints=['trustview', 'bandwidth'],
values={'bandwidth_db': community.database, 'bandwidth_community': community})

async def shutdown(self):
self._rest_manager.get_endpoint('trustview').bandwidth_db = None
self._rest_manager.get_endpoint('bandwidth').bandwidth_community = None
await super().shutdown()
await self._ipv8.unload_overlay(self.community)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class BandwidthEndpoint(RESTEndpoint):

def __init__(self):
super().__init__()
self.bandwidth_db = None # added to simlify the initialization code of BandwidthAccountingComponent
self.bandwidth_community = None

def setup_routes(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@


async def test_bandwidth_accounting_component(tribler_config):
tribler_config.ipv8.enabled = True
components = [RESTComponent(), MasterKeyComponent(), Ipv8Component(), BandwidthAccountingComponent()]
session = Session(tribler_config, components)
with session:
comp = BandwidthAccountingComponent.instance()
with patch.object(RESTManager, 'get_endpoint'):
await session.start()
await session.start()

assert comp.community
assert comp._rest_manager
assert comp._ipv8
comp = BandwidthAccountingComponent.instance()
assert comp.started.is_set() and not comp.failed
assert comp.community
assert comp._ipv8

await session.shutdown()
await session.shutdown()
2 changes: 1 addition & 1 deletion src/tribler-core/tribler_core/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ async def get_component(self, dependency: Type[T]) -> Optional[T]:
dep.in_use_by.add(self)
return dep

async def release_component(self, dependency: Type[T]):
def release_component(self, dependency: Type[T]):
dep = dependency.instance()
if dep:
self._release_instance(dep)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
from ipv8.peerdiscovery.discovery import RandomWalk

from ipv8_service import IPv8

from tribler_core.components.base import Component
from tribler_core.components.ipv8 import Ipv8Component
from tribler_core.components.gigachannel.community.gigachannel_community import GigaChannelCommunity, \
GigaChannelTestnetCommunity
from tribler_core.components.metadata_store.metadata_store_component import MetadataStoreComponent
from tribler_core.components.reporter import ReporterComponent
from tribler_core.components.restapi import RESTComponent
from tribler_core.components.restapi import RestfulComponent
from tribler_core.components.gigachannel.community.sync_strategy import RemovePeers
from tribler_core.restapi.rest_manager import RESTManager

INFINITE = -1


class GigaChannelComponent(Component):
class GigaChannelComponent(RestfulComponent):
community: GigaChannelCommunity

_rest_manager: RESTManager
_ipv8: IPv8

async def run(self):
await super().run()
await self.get_component(ReporterComponent)

config = self.session.config
Expand All @@ -29,9 +31,6 @@ async def run(self):
self._ipv8 = ipv8_component.ipv8
peer = ipv8_component.peer

rest_component = await self.require_component(RESTComponent)
self._rest_manager = rest_component.rest_manager

metadata_store_component = await self.require_component(MetadataStoreComponent)

giga_channel_cls = GigaChannelTestnetCommunity if config.general.testnet else GigaChannelCommunity
Expand All @@ -52,14 +51,10 @@ async def run(self):

community.bootstrappers.append(ipv8_component.make_bootstrapper())

self._rest_manager.get_endpoint('remote_query').gigachannel_community = community
self._rest_manager.get_endpoint('channels').gigachannel_community = community
self._rest_manager.get_endpoint('collections').gigachannel_community = community
await self.init_endpoints(endpoints=['remote_query', 'channels', 'collections'],
values={'gigachannel_community': community})

async def shutdown(self):
self._rest_manager.get_endpoint('remote_query').gigachannel_community = None
self._rest_manager.get_endpoint('channels').gigachannel_community = None
self._rest_manager.get_endpoint('collections').gigachannel_community = None
await self.release_component(RESTComponent)
await super().shutdown()
if self._ipv8:
await self._ipv8.unload_overlay(self.community)
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@


async def test_giga_channel_component(tribler_config):
tribler_config.ipv8.enabled = True
tribler_config.libtorrent.enabled = True
tribler_config.chant.enabled = True
components = [MetadataStoreComponent(), RESTComponent(), MasterKeyComponent(), Ipv8Component(),
GigaChannelComponent()]
session = Session(tribler_config, components)
with session:
comp = GigaChannelComponent.instance()
with patch.object(RESTManager, 'get_endpoint'):
await session.start()
await session.start()

assert comp.community
assert comp._rest_manager
assert comp._ipv8
comp = GigaChannelComponent.instance()
assert comp.started.is_set() and not comp.failed
assert comp.community
assert comp._ipv8

await session.shutdown()
await session.shutdown()
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@
from tribler_core.components.libtorrent import LibtorrentComponent
from tribler_core.components.metadata_store.metadata_store_component import MetadataStoreComponent
from tribler_core.components.reporter import ReporterComponent
from tribler_core.components.restapi import RESTComponent
from tribler_core.components.restapi import RestfulComponent
from tribler_core.components.gigachannel_manager.gigachannel_manager import GigaChannelManager
from tribler_core.restapi.rest_manager import RESTManager


class GigachannelManagerComponent(Component):
class GigachannelManagerComponent(RestfulComponent):
gigachannel_manager: GigaChannelManager

_rest_manager: RESTManager

async def run(self):
await self.get_component(ReporterComponent)
await super().run()

config = self.session.config
notifier = self.session.notifier
Expand All @@ -22,24 +20,17 @@ async def run(self):
download_manager = libtorrent_component.download_manager if libtorrent_component else None

metadata_store_component = await self.require_component(MetadataStoreComponent)
rest_component = await self.require_component(RESTComponent)

self._rest_manager = rest_component.rest_manager

self.gigachannel_manager = GigaChannelManager(
notifier=notifier, metadata_store=metadata_store_component.mds, download_manager=download_manager
)
if not config.gui_test_mode:
self.gigachannel_manager.start()

self._rest_manager.get_endpoint('channels').gigachannel_manager = self.gigachannel_manager
self._rest_manager.get_endpoint('collections').gigachannel_manager = self.gigachannel_manager
await self.init_endpoints(endpoints=['channels', 'collections'],
values={'gigachannel_manager': self.gigachannel_manager})

async def shutdown(self):
self.session.notifier.notify_shutdown_state("Shutting down Gigachannel Manager...")
self._rest_manager.get_endpoint('channels').gigachannel_manager = None
self._rest_manager.get_endpoint('collections').gigachannel_manager = None

await self.release_component(RESTComponent)

await super().shutdown()
await self.gigachannel_manager.shutdown()
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
# pylint: disable=protected-access

async def test_gigachannel_manager_component(tribler_config):
tribler_config.ipv8.enabled = True
tribler_config.libtorrent.enabled = True
tribler_config.chant.enabled = True
components = [SocksServersComponent(), MasterKeyComponent(), RESTComponent(), MetadataStoreComponent(),
LibtorrentComponent(), GigachannelManagerComponent()]
session = Session(tribler_config, components)
with session:
comp = GigachannelManagerComponent.instance()
with patch.object(RESTManager, 'get_endpoint'):
await session.start()
await session.start()

assert comp.gigachannel_manager
assert comp._rest_manager
assert comp.started.is_set() and not comp.failed
assert comp.gigachannel_manager

await session.shutdown()
await session.shutdown()
29 changes: 8 additions & 21 deletions src/tribler-core/tribler_core/components/ipv8.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,25 @@
from tribler_core.components.base import Component
from tribler_core.components.masterkey import MasterKeyComponent
from tribler_core.components.reporter import ReporterComponent
from tribler_core.components.restapi import RESTComponent
from tribler_core.components.restapi import RestfulComponent
from tribler_core.restapi.rest_manager import RESTManager

INFINITE = -1


class Ipv8Component(Component):
class Ipv8Component(RestfulComponent):
ipv8: IPv8
peer: Peer
dht_discovery_community: Optional[DHTDiscoveryCommunity] = None

_task_manager: TaskManager
_rest_manager: Optional[RESTManager]
_peer_discovery_community: Optional[DiscoveryCommunity] = None

async def run(self):
await self.get_component(ReporterComponent)
await super().run()

config = self.session.config

rest_component = await self.get_component(RESTComponent)
self._rest_manager = rest_component.rest_manager if rest_component else None

self._task_manager = TaskManager()

port = config.ipv8.port
Expand Down Expand Up @@ -81,9 +77,6 @@ async def run(self):
config.ipv8.walk_interval,
config.ipv8.walk_scaling_upper_limit).start(self._task_manager)

if self._rest_manager:
self._rest_manager.get_endpoint('statistics').ipv8 = ipv8

if config.dht.enabled:
self.init_dht_discovery_community()

Expand All @@ -94,13 +87,10 @@ async def run(self):
if config.dht.enabled:
self.dht_discovery_community.routing_tables[UDPv4Address] = RoutingTable('\x00' * 20)

endpoints_to_init = ['/asyncio', '/attestation', '/dht', '/identity',
'/isolation', '/network', '/noblockdht', '/overlays']

if self._rest_manager:
for path, endpoint in self._rest_manager.get_endpoint('ipv8').endpoints.items():
if path in endpoints_to_init:
endpoint.initialize(ipv8)
await self.init_endpoints(endpoints=['statistics'], values={'ipv8': ipv8})
await self.init_ipv8_endpoints(ipv8, endpoints=[
'asyncio', 'attestation', 'dht', 'identity', 'isolation', 'network', 'noblockdht', 'overlays'
])

def make_bootstrapper(self) -> DispersyBootstrapper:
args = DISPERSY_BOOTSTRAPPER['init']
Expand All @@ -127,15 +117,12 @@ def init_dht_discovery_community(self):
self.dht_discovery_community = community

async def shutdown(self):
if self._rest_manager:
self._rest_manager.get_endpoint('statistics').ipv8 = None
await self.release_component(RESTComponent)
await super().shutdown()

for overlay in (self.dht_discovery_community, self._peer_discovery_community):
if overlay:
await self.ipv8.unload_overlay(overlay)

await self.unused.wait()
self.session.notifier.notify_shutdown_state("Shutting down IPv8...")
await self._task_manager.shutdown_task_manager()
await self.ipv8.stop(stop_loop=False)
Loading