Skip to content

Commit

Permalink
Merge branch 'feature/components_interface' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Sep 27, 2021
2 parents fc31a95 + 7ccc387 commit 84b11ef
Show file tree
Hide file tree
Showing 54 changed files with 853 additions and 1,863 deletions.
24 changes: 12 additions & 12 deletions src/tribler-core/run_tunnel_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from tribler_common.simpledefs import NTFY

from tribler_core.components.base import Session
from tribler_core.components.interfaces.bandwidth_accounting import BandwidthAccountingComponent
from tribler_core.components.interfaces.ipv8 import Ipv8Component
from tribler_core.components.interfaces.masterkey import MasterKeyComponent
from tribler_core.components.interfaces.resource_monitor import ResourceMonitorComponent
from tribler_core.components.interfaces.restapi import RESTComponent
from tribler_core.components.interfaces.socks_configurator import SocksServersComponent
from tribler_core.components.interfaces.tunnels import TunnelsComponent
from tribler_core.components.interfaces.upgrade import UpgradeComponent
from tribler_core.components.bandwidth_accounting import BandwidthAccountingComponent
from tribler_core.components.ipv8 import Ipv8Component
from tribler_core.components.masterkey import MasterKeyComponent
from tribler_core.components.resource_monitor import ResourceMonitorComponent
from tribler_core.components.restapi import RESTComponent
from tribler_core.components.socks_configurator import SocksServersComponent
from tribler_core.components.tunnels import TunnelsComponent
from tribler_core.components.upgrade import UpgradeComponent
from tribler_core.config.tribler_config import TriblerConfig
from tribler_core.utilities.osutils import get_root_state_directory
from tribler_core.utilities.path_util import Path
Expand Down Expand Up @@ -120,7 +120,7 @@ async def signal_handler(sig):
signal.signal(signal.SIGINT, lambda sig, _: ensure_future(signal_handler(sig)))
signal.signal(signal.SIGTERM, lambda sig, _: ensure_future(signal_handler(sig)))

tunnel_community = TunnelsComponent.imp().community
tunnel_community = TunnelsComponent.instance().community
self.register_task("bootstrap", tunnel_community.bootstrap, interval=30)

# Remove all logging handlers
Expand All @@ -130,7 +130,7 @@ async def signal_handler(sig):
root_logger.removeHandler(handler)
logging.getLogger().setLevel(logging.ERROR)

ipv8 = Ipv8Component.imp().ipv8
ipv8 = Ipv8Component.instance().ipv8
new_strategies = []
with ipv8.overlay_lock:
for strategy, target_peers in ipv8.strategies:
Expand All @@ -141,7 +141,7 @@ async def signal_handler(sig):
ipv8.strategies = new_strategies

def circuit_removed(self, circuit, additional_info):
ipv8 = Ipv8Component.imp().ipv8
ipv8 = Ipv8Component.instance().ipv8
ipv8.network.remove_by_address(circuit.peer.address)
if self.log_circuits:
with open(os.path.join(self.session.config.state_dir, "circuits.log"), 'a') as out_file:
Expand All @@ -162,7 +162,7 @@ async def start(self, options):

with session:
if options.log_rejects:
tunnels_component = TunnelsComponent.imp()
tunnels_component = TunnelsComponent.instance()
tunnels_community = tunnels_component.community
# We set this after Tribler has started since the tunnel_community won't be available otherwise
tunnels_community.reject_callback = self.on_circuit_reject
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from ipv8.peerdiscovery.discovery import RandomWalk

from ipv8_service import IPv8
from tribler_common.simpledefs import STATEDIR_DB_DIR

from tribler_core.components.interfaces.bandwidth_accounting import BandwidthAccountingComponent
from tribler_core.components.interfaces.ipv8 import Ipv8Component
from tribler_core.components.interfaces.reporter import ReporterComponent
from tribler_core.components.interfaces.restapi import RESTComponent
from tribler_core.components.interfaces.upgrade import UpgradeComponent
from tribler_core.components.base import Component
from tribler_core.components.ipv8 import Ipv8Component
from tribler_core.components.reporter import ReporterComponent
from tribler_core.components.restapi import RESTComponent
from tribler_core.components.upgrade import UpgradeComponent
from tribler_core.modules.bandwidth_accounting.community import (
BandwidthAccountingCommunity,
BandwidthAccountingTestnetCommunity,
Expand All @@ -15,18 +14,23 @@
from tribler_core.restapi.rest_manager import RESTManager


class BandwidthAccountingComponentImp(BandwidthAccountingComponent):
rest_manager: RESTManager
class BandwidthAccountingComponent(Component):
community: BandwidthAccountingCommunity

_rest_manager: RESTManager
_ipv8: IPv8

async def run(self):
await self.use(ReporterComponent, required=False)
await self.use(UpgradeComponent, required=False)
await self.get_component(ReporterComponent)
await self.get_component(UpgradeComponent)
config = self.session.config

ipv8_component = await self.use(Ipv8Component)
ipv8 = self._ipv8 = ipv8_component.ipv8
ipv8_component = await self.require_component(Ipv8Component)
self._ipv8 = ipv8_component.ipv8
peer = ipv8_component.peer
rest_manager = self.rest_manager = (await self.use(RESTComponent)).rest_manager

rest_component = await self.require_component(RESTComponent)
self._rest_manager = rest_component.rest_manager

if config.general.testnet or config.bandwidth_accounting.testnet:
bandwidth_cls = BandwidthAccountingTestnetCommunity
Expand All @@ -36,19 +40,18 @@ async def run(self):
db_name = "bandwidth_gui_test.db" if config.gui_test_mode else f"{bandwidth_cls.DB_NAME}.db"
database_path = config.state_dir / STATEDIR_DB_DIR / db_name
database = BandwidthDatabase(database_path, peer.public_key.key_to_bin())
community = bandwidth_cls(peer, ipv8.endpoint, ipv8.network,
community = bandwidth_cls(peer, self._ipv8.endpoint, self._ipv8.network,
settings=config.bandwidth_accounting,
database=database)
ipv8.add_strategy(community, RandomWalk(community), 20)
self._ipv8.add_strategy(community, RandomWalk(community), 20)

community.bootstrappers.append(ipv8_component.make_bootstrapper())

self.community = community

rest_manager.get_endpoint('trustview').bandwidth_db = community.database
rest_manager.get_endpoint('bandwidth').bandwidth_community = community
self._rest_manager.get_endpoint('trustview').bandwidth_db = community.database
self._rest_manager.get_endpoint('bandwidth').bandwidth_community = community

async def shutdown(self):
self.rest_manager.get_endpoint('trustview').bandwidth_db = None
self.rest_manager.get_endpoint('bandwidth').bandwidth_community = None
self._rest_manager.get_endpoint('trustview').bandwidth_db = None
self._rest_manager.get_endpoint('bandwidth').bandwidth_community = None
await self._ipv8.unload_overlay(self.community)
89 changes: 42 additions & 47 deletions src/tribler-core/tribler_core/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import logging
import os
import sys
from abc import abstractmethod
from asyncio import Event, create_task, gather
from itertools import count
from pathlib import Path
from typing import Dict, List, Optional, Set, Type, TypeVar

from tribler_common.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR

from tribler_core.config.tribler_config import TriblerConfig
from tribler_core.notifier import Notifier
from tribler_core.utilities.crypto_patcher import patch_crypto_be_discovery
Expand Down Expand Up @@ -46,8 +44,8 @@ def __init__(self, config: TriblerConfig = None, components: List[Component] = (
self.shutdown_event: Event = shutdown_event or Event()
self.notifier: Notifier = notifier or Notifier()
self.components: Dict[Type[Component], Component] = {}
for implementation in components:
self.register(implementation.interface, implementation)
for component in components:
self.register(component.__class__, component)

def __repr__(self):
return f'<{self.__class__.__name__}:{self.id}>'
Expand Down Expand Up @@ -106,13 +104,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class Component:
enable_in_gui_test_mode = False
enabled = True

def __init__(self, interface: Type[Component]):
assert isinstance(self, interface)
self.interface = interface
self.logger = logging.getLogger(interface.__name__)
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
self.logger.info('__init__')
self.session: Optional[Session] = None
self.components_used_by_me: Set[Component] = set()
Expand All @@ -125,31 +118,12 @@ def __init__(self, interface: Type[Component]):
self.unused.set()

@classmethod
def should_be_enabled(cls, config: TriblerConfig): # pylint: disable=unused-argument
return False

@classmethod
@abstractmethod
def make_implementation(cls: Type[T], config, enable) -> T:
assert False, f"Abstract classmethod make_implementation not implemented in class {cls.__name__}"

@classmethod
def _find_implementation(cls: Type[T], required=True) -> T:
def instance(cls: Type[T]) -> T:
session = Session.current()
imp = session.components.get(cls)
if imp is None:
if required:
raise ComponentError(f"{cls.__name__} implementation not found in {session}")
imp = cls.make_implementation(session.config, enable=False) # dummy implementation
session.register(cls, imp)
imp.started.set()
return imp

@classmethod
def imp(cls: Type[T], required=True) -> T:
return cls._find_implementation(required=required)
return session.components.get(cls)

async def start(self):
self.logger.info(f'Start: {self.__class__.__name__}')
try:
await self.run()
except Exception as e:
Expand All @@ -163,13 +137,14 @@ async def start(self):
self.started.set()

async def stop(self):
self.logger.info(f'Stop: {self.__class__.__name__}')
self.logger.info("Waiting for other components to release me")
await self.unused.wait()
self.logger.info("Component free, shutting down")
await self.shutdown()
self.stopped = True
for dep in list(self.components_used_by_me):
self._release_imp(dep)
self._release_instance(dep)
self.logger.info("Component free, shutting down")

async def run(self):
Expand All @@ -178,27 +153,47 @@ async def run(self):
async def shutdown(self):
pass

async def use(self, dependency: Type[T], required=True) -> T:
dep = dependency.imp(required=required)
async def require_component(self, dependency: Type[T]) -> T:
""" Resolve the dependency to a component.
The method will wait the component to be initialised.
Returns: The component instance.
In case of a missed or failed dependency an exception will be raised.
"""
dep = await self.get_component(dependency)
if not dep:
raise ComponentError(
f'Missed dependency: {self.__class__.__name__} requires {dependency.__name__} to be active')
return dep

async def get_component(self, dependency: Type[T]) -> Optional[T]:
""" Resolve the dependency to a component.
The method will wait the component to be initialised.
Returns: The component instance.
In case of a missed or failed dependency None will be returned.
"""
dep = dependency.instance()
if not dep:
return None

await dep.started.wait()
if dep.failed:
raise ComponentError(f'Component {self.__class__.__name__} has failed dependency {dep.__class__.__name__}')
self.logger.warning(f'Component {self.__class__.__name__} has failed dependency {dependency.__name__}')
return None

self.components_used_by_me.add(dep)
dep.in_use_by.add(self)
return dep

def _release_imp(self, dep: Component):
async def release_component(self, dependency: Type[T]):
dep = dependency.instance()
if dep:
self._release_instance(dep)

def _release_instance(self, dep: Component):
assert dep in self.components_used_by_me
self.components_used_by_me.discard(dep)
dep.in_use_by.discard(self)
if not dep.in_use_by:
dep.unused.set()

async def release(self, dependency: Type[T]):
dep = dependency.imp()
self._release_imp(dep)


def testcomponent(component_cls):
component_cls.enabled = False
return component_cls
53 changes: 0 additions & 53 deletions src/tribler-core/tribler_core/components/components_catalog.py

This file was deleted.

Loading

0 comments on commit 84b11ef

Please sign in to comment.