diff --git a/mautrix/__init__.py b/mautrix/__init__.py index aa9cfbe0..8ac5679f 100644 --- a/mautrix/__init__.py +++ b/mautrix/__init__.py @@ -1,3 +1,3 @@ -__version__ = "0.7.10" +__version__ = "0.7.11" __author__ = "Tulir Asokan " __all__ = ["api", "appservice", "bridge", "client", "crypto", "errors", "util", "types"] diff --git a/mautrix/bridge/commands/__init__.py b/mautrix/bridge/commands/__init__.py index 6c37a581..d7b33f1a 100644 --- a/mautrix/bridge/commands/__init__.py +++ b/mautrix/bridge/commands/__init__.py @@ -1,6 +1,7 @@ from .handler import (HelpSection, HelpCacheKey, command_handler, CommandHandler, CommandProcessor, - CommandHandlerFunc, CommandEvent, SECTION_GENERAL) + CommandHandlerFunc, CommandEvent, SECTION_GENERAL, SECTION_ADMIN) from .meta import cancel, unknown_command, help_cmd +from . import admin __all__ = ["HelpSection", "HelpCacheKey", "command_handler", "CommandHandler", "CommandProcessor", - "CommandHandlerFunc", "CommandEvent", "SECTION_GENERAL"] + "CommandHandlerFunc", "CommandEvent", "SECTION_GENERAL", "SECTION_ADMIN"] diff --git a/mautrix/bridge/commands/admin.py b/mautrix/bridge/commands/admin.py new file mode 100644 index 00000000..d8d97cb9 --- /dev/null +++ b/mautrix/bridge/commands/admin.py @@ -0,0 +1,37 @@ +# Copyright (c) 2020 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from mautrix.types import EventID + +from mautrix.errors import (MatrixRequestError, IntentError) + +from .handler import (command_handler, CommandEvent, SECTION_ADMIN) + + +@command_handler(needs_admin=True, needs_auth=False, name="set-pl", + help_section=SECTION_ADMIN, + help_args="<_level_> [_mxid_]", + help_text="Set a temporary power level without affecting the bridge.") +async def set_power_level(evt: CommandEvent) -> EventID: + try: + level = int(evt.args[0]) + except (KeyError, IndexError): + return await evt.reply("**Usage:** `$cmdprefix+sp set-pl [mxid]`") + except ValueError: + return await evt.reply("The level must be an integer.") + if evt.is_portal: + portal = await evt.processor.bridge.get_portal(evt.room_id) + intent = portal.main_intent + else: + intent = evt.az.intent + levels = await intent.get_power_levels(evt.room_id) + mxid = evt.args[1] if len(evt.args) > 1 else evt.sender.mxid + levels.users[mxid] = level + try: + return await intent.set_power_levels(evt.room_id, levels) + except (MatrixRequestError, IntentError): + evt.log.exception("Failed to set power level.") + return await evt.reply("Failed to set power level.") diff --git a/mautrix/bridge/commands/handler.py b/mautrix/bridge/commands/handler.py index dc043305..7404f7da 100644 --- a/mautrix/bridge/commands/handler.py +++ b/mautrix/bridge/commands/handler.py @@ -25,6 +25,7 @@ HelpCacheKey = NamedTuple('HelpCacheKey', is_management=bool, is_portal=bool) SECTION_GENERAL = HelpSection("General", 0, "") +SECTION_ADMIN = HelpSection("Administration", 50, "") def ensure_trailing_newline(s: str) -> str: @@ -112,7 +113,7 @@ def print_error_traceback(self) -> bool: """ return self.is_management - def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True + async def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True ) -> Awaitable[EventID]: """Write a reply to the room in which the command was issued. @@ -136,11 +137,16 @@ def reply(self, message: str, allow_html: bool = False, render_markdown: bool = html = self._render_message(message, allow_html=allow_html, render_markdown=render_markdown) - return self.az.intent.send_notice(self.room_id, message, html=html) + if self.is_portal: + portal = await self.processor.bridge.get_portal(self.room_id) + return await portal.main_intent.send_notice(self.room_id, message, html=html) + else: + return await self.az.intent.send_notice(self.room_id, message, html=html) def mark_read(self) -> Awaitable[None]: """Marks the command as read by the bot.""" - return self.az.intent.mark_read(self.room_id, self.event_id) + if not self.is_portal: + return self.az.intent.mark_read(self.room_id, self.event_id) def _replace_command_prefix(self, message: str) -> str: """Returns the string with the proper command prefix entered.""" @@ -184,20 +190,26 @@ class CommandHandler: name: The name of this command. help_section: Section of the help in which this command will appear. """ - management_only: bool name: str + management_only: bool + needs_admin: bool + needs_auth: bool + _help_text: str _help_args: str help_section: HelpSection def __init__(self, handler: CommandHandlerFunc, management_only: bool, name: str, - help_text: str, help_args: str, help_section: HelpSection, **kwargs) -> None: + help_text: str, help_args: str, help_section: HelpSection, + needs_auth: bool, needs_admin: bool, **kwargs) -> None: """ Args: handler: The function handling the execution of this command. management_only: Whether the command can exclusively be issued in a management room. + needs_auth: Whether the command needs the bridge to be authed already + needs_admin: Whether the command needs the issuer to be bridge admin name: The name of this command. help_text: The text displayed in the help for this command. help_args: Help text for the arguments of this command. @@ -207,6 +219,8 @@ def __init__(self, handler: CommandHandlerFunc, management_only: bool, name: str setattr(self, key, value) self._handler = handler self.management_only = management_only + self.needs_admin = needs_admin + self.needs_auth = needs_auth self.name = name self._help_text = help_text self._help_args = help_args @@ -224,6 +238,10 @@ async def get_permission_error(self, evt: CommandEvent) -> Optional[str]: if self.management_only and not evt.is_management: return (f"`{evt.command}` is a restricted command: " "you may only run it in management rooms.") + elif self.needs_admin and not evt.sender.is_admin: + return "This command requires administrator privileges." + elif self.needs_auth and not await evt.sender.is_logged_in(): + return "This command requires you to be logged in." return None def has_permission(self, key: HelpCacheKey) -> bool: @@ -236,7 +254,9 @@ def has_permission(self, key: HelpCacheKey) -> bool: True if a user with the given state is allowed to issue the command. """ - return not self.management_only or key.is_management + return ((not self.management_only or key.is_management) and + (not self.needs_admin or key.is_admin) and + (not self.needs_auth or key.is_logged_in)) async def __call__(self, evt: CommandEvent) -> Any: """Executes the command if evt was issued with proper rights. @@ -267,13 +287,14 @@ def command_handler(_func: Optional[CommandHandlerFunc] = None, *, management_on name: Optional[str] = None, help_text: str = "", help_args: str = "", help_section: HelpSection = None, aliases: Optional[List[str]] = None, _handler_class: Type[CommandHandler] = CommandHandler, + needs_auth: bool = True, needs_admin: bool = False, **kwargs) -> Callable[[CommandHandlerFunc], CommandHandler]: """Decorator to create CommandHandlers""" def decorator(func: CommandHandlerFunc) -> CommandHandler: actual_name = name or func.__name__.replace("_", "-") handler = _handler_class(func, management_only, actual_name, help_text, help_args, - help_section, **kwargs) + help_section, needs_auth, needs_admin, **kwargs) command_handlers[handler.name] = handler if aliases: for alias in aliases: diff --git a/mautrix/crypto/encrypt_megolm.py b/mautrix/crypto/encrypt_megolm.py index f776c3f8..c6c829e2 100644 --- a/mautrix/crypto/encrypt_megolm.py +++ b/mautrix/crypto/encrypt_megolm.py @@ -3,7 +3,7 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Tuple from collections import defaultdict from datetime import timedelta import asyncio @@ -18,7 +18,7 @@ from .types import DeviceIdentity, TrustState from .encrypt_olm import OlmEncryptionMachine from .device_lists import DeviceListMachine -from .sessions import OutboundGroupSession, InboundGroupSession +from .sessions import OutboundGroupSession, InboundGroupSession, Session class Sentinel: @@ -28,9 +28,11 @@ class Sentinel: already_shared = Sentinel() key_missing = Sentinel() +DeviceSessionWrapper = Tuple[Session, DeviceIdentity] +DeviceMap = Dict[UserID, Dict[DeviceID, DeviceSessionWrapper]] SessionEncryptResult = Union[ type(already_shared), # already shared - EncryptedOlmEventContent, # share successful + DeviceSessionWrapper, # share successful RoomKeyWithheldEventContent, # won't share type(key_missing), # missing device ] @@ -38,13 +40,11 @@ class Sentinel: class MegolmEncryptionMachine(OlmEncryptionMachine, DeviceListMachine): _megolm_locks: Dict[RoomID, asyncio.Lock] - _olm_locks: Dict[IdentityKey, asyncio.Lock] _sharing_group_session: Dict[RoomID, asyncio.Event] def __init__(self) -> None: super().__init__() self._megolm_locks = defaultdict(lambda: asyncio.Lock()) - self._olm_locks = defaultdict(lambda: asyncio.Lock()) self._sharing_group_session = {} async def encrypt_megolm_event(self, room_id: RoomID, event_type: EventType, content: Any @@ -89,8 +89,8 @@ async def _encrypt_megolm_event(self, room_id: RoomID, event_type: EventType, co relates_to = None await self.crypto_store.update_outbound_group_session(session) return EncryptedMegolmEventContent(sender_key=self.account.identity_key, - device_id=self.client.device_id, session_id=session.id, - ciphertext=ciphertext, relates_to=relates_to) + device_id=self.client.device_id, ciphertext=ciphertext, + session_id=SessionID(session.id), relates_to=relates_to) def is_sharing_group_session(self, room_id: RoomID) -> bool: """ @@ -158,7 +158,7 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No self.log.debug("Got stored encryption state event and configured session to rotate " f"after {session.max_messages} messages or {session.max_age}") - share_key_msgs = defaultdict(lambda: {}) + olm_sessions: DeviceMap = defaultdict(lambda: {}) withhold_key_msgs = defaultdict(lambda: {}) missing_sessions: Dict[UserID, Dict[DeviceID, DeviceIdentity]] = defaultdict(lambda: {}) fetch_keys = [] @@ -173,13 +173,13 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No else: self.log.debug(f"Trying to encrypt group session {session.id} for {user_id}") for device_id, device in devices.items(): - result = await self._encrypt_group_session(session, user_id, device_id, device) - if isinstance(result, EncryptedOlmEventContent): - share_key_msgs[user_id][device_id] = result - elif isinstance(result, RoomKeyWithheldEventContent): + result = await self._find_olm_sessions(session, user_id, device_id, device) + if isinstance(result, RoomKeyWithheldEventContent): withhold_key_msgs[user_id][device_id] = result elif result == key_missing: missing_sessions[user_id][device_id] = device + elif isinstance(result, tuple): + olm_sessions[user_id][device_id] = result if fetch_keys: self.log.debug(f"Fetching missing keys for {fetch_keys}") @@ -193,17 +193,16 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No for user_id, devices in missing_sessions.items(): for device_id, device in devices.items(): - result = await self._encrypt_group_session(session, user_id, device_id, device) - if isinstance(result, EncryptedOlmEventContent): - share_key_msgs[user_id][device_id] = result - elif isinstance(result, RoomKeyWithheldEventContent): + result = await self._find_olm_sessions(session, user_id, device_id, device) + if isinstance(result, RoomKeyWithheldEventContent): withhold_key_msgs[user_id][device_id] = result + elif isinstance(result, tuple): + olm_sessions[user_id][device_id] = result # We don't care about missing keys at this point - if len(share_key_msgs) > 0: - event_count = sum(len(map) for map in share_key_msgs.values()) - self.log.debug(f"Sending {event_count} to-device events to share {session.id}") - await self.client.send_to_device(EventType.TO_DEVICE_ENCRYPTED, share_key_msgs) + if len(olm_sessions) > 0: + async with self._olm_lock: + await self._encrypt_and_share_group_session(session, olm_sessions) if len(withhold_key_msgs) > 0: event_count = sum(len(map) for map in withhold_key_msgs.values()) self.log.debug(f"Sending {event_count} to-device events " @@ -221,6 +220,19 @@ async def _new_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSes room_id, SessionID(session.id), session.session_key) return session + async def _encrypt_and_share_group_session(self, session: OutboundGroupSession, + olm_sessions: DeviceMap): + msgs = defaultdict(lambda: {}) + count = 0 + for user_id, devices in olm_sessions.items(): + count += len(devices) + for device_id, (olm_session, device_identity) in devices.items(): + msgs[user_id][device_id] = await self._encrypt_olm_event( + olm_session, device_identity, EventType.ROOM_KEY, session.share_content) + self.log.debug(f"Sending to-device events to {count} devices of {len(msgs)} users " + f"to share {session.id}") + await self.client.send_to_device(EventType.TO_DEVICE_ENCRYPTED, msgs) + async def _create_group_session(self, sender_key: IdentityKey, signing_key: SigningKey, room_id: RoomID, session_id: SessionID, session_key: str ) -> None: @@ -231,15 +243,9 @@ async def _create_group_session(self, sender_key: IdentityKey, signing_key: Sign self._mark_session_received(session_id) self.log.debug(f"Created inbound group session {room_id}/{sender_key}/{session_id}") - async def _encrypt_group_session(self, session: OutboundGroupSession, user_id: UserID, - device_id: DeviceID, device: DeviceIdentity - ) -> SessionEncryptResult: - async with self._olm_locks[device.identity_key]: - return await self._encrypt_group_session_locked(session, user_id, device_id, device) - - async def _encrypt_group_session_locked(self, session: OutboundGroupSession, user_id: UserID, - device_id: DeviceID, device: DeviceIdentity - ) -> SessionEncryptResult: + async def _find_olm_sessions(self, session: OutboundGroupSession, user_id: UserID, + device_id: DeviceID, device: DeviceIdentity + ) -> SessionEncryptResult: key = (user_id, device_id) if key in session.users_ignored or key in session.users_shared_with: return already_shared @@ -267,8 +273,5 @@ async def _encrypt_group_session_locked(self, session: OutboundGroupSession, use device_session = await self.crypto_store.get_latest_session(device.identity_key) if not device_session: return key_missing - encrypted = await self._encrypt_olm_event(device_session, device, EventType.ROOM_KEY, - session.share_content) session.users_shared_with.add(key) - self.log.debug(f"Encrypted group session {session.id} for {device_id} of {user_id}") - return encrypted + return device_session, device diff --git a/mautrix/crypto/encrypt_olm.py b/mautrix/crypto/encrypt_olm.py index 2b9a2aeb..68178419 100644 --- a/mautrix/crypto/encrypt_olm.py +++ b/mautrix/crypto/encrypt_olm.py @@ -18,9 +18,11 @@ class OlmEncryptionMachine(BaseOlmMachine): _claim_keys_lock: asyncio.Lock + _olm_lock: asyncio.Lock def __init__(self): self._claim_keys_lock = asyncio.Lock() + self._olm_lock = asyncio.Lock() async def _encrypt_olm_event(self, session: Session, recipient: DeviceIdentity, event_type: EventType, content: Any) -> EncryptedOlmEventContent: @@ -66,6 +68,7 @@ async def send_encrypted_to_device(self, device: DeviceIdentity, event_type: Eve content: ToDeviceEventContent) -> None: await self._create_outbound_sessions({device.user_id: {device.device_id: device}}) session = await self.crypto_store.get_latest_session(device.identity_key) - encrypted_content = await self._encrypt_olm_event(session, device, event_type, content) - await self.client.send_to_one_device(EventType.TO_DEVICE_ENCRYPTED, device.user_id, - device.device_id, encrypted_content) + async with self._olm_lock: + encrypted_content = await self._encrypt_olm_event(session, device, event_type, content) + await self.client.send_to_one_device(EventType.TO_DEVICE_ENCRYPTED, device.user_id, + device.device_id, encrypted_content) diff --git a/mautrix/crypto/store/asyncpg/store.py b/mautrix/crypto/store/asyncpg/store.py index 26e82210..9496593b 100644 --- a/mautrix/crypto/store/asyncpg/store.py +++ b/mautrix/crypto/store/asyncpg/store.py @@ -4,6 +4,7 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Dict, Optional, List +from collections import defaultdict from mautrix.types import SyncToken, IdentityKey, SessionID, RoomID, EventID, UserID, DeviceID from mautrix.client.state_store import SyncStore @@ -27,6 +28,7 @@ class PgCryptoStore(CryptoStore, SyncStore): _sync_token: Optional[SyncToken] _device_id: Optional[DeviceID] _account: Optional[OlmAccount] + _olm_cache: Dict[IdentityKey, Dict[SessionID, Session]] def __init__(self, account_id: str, pickle_key: str, db: Database) -> None: self.db = db @@ -36,6 +38,7 @@ def __init__(self, account_id: str, pickle_key: str, db: Database) -> None: self._sync_token = None self._device_id = "" self._account = None + self._olm_cache = defaultdict(lambda: {}) async def get_device_id(self) -> Optional[DeviceID]: device_id = await self.db.fetchval("SELECT device_id FROM crypto_account " @@ -79,32 +82,45 @@ async def get_account(self) -> OlmAccount: return self._account async def has_session(self, key: IdentityKey) -> bool: + if len(self._olm_cache[key]) > 0: + return True val = await self.db.fetchval("SELECT session_id FROM crypto_olm_session " "WHERE sender_key=$1 AND account_id=$2", key, self.account_id) return val is not None async def get_sessions(self, key: IdentityKey) -> List[Session]: - rows = await self.db.fetch("SELECT session, created_at, last_used FROM crypto_olm_session " - "WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id", + rows = await self.db.fetch("SELECT session_id, session, created_at, last_used " + "FROM crypto_olm_session " + "WHERE sender_key=$1 AND account_id=$2 " + "ORDER BY session_id", key, self.account_id) sessions = [] for row in rows: - sess = Session.from_pickle(row["session"], passphrase=self.pickle_key, - creation_time=row["created_at"], use_time=row["last_used"]) + try: + sess = self._olm_cache[key][row["session_id"]] + except KeyError: + sess = Session.from_pickle(row["session"], passphrase=self.pickle_key, + creation_time=row["created_at"], + use_time=row["last_used"]) sessions.append(sess) return sessions async def get_latest_session(self, key: IdentityKey) -> Optional[Session]: - row = await self.db.fetchrow("SELECT session, created_at, last_used FROM crypto_olm_session" - " WHERE sender_key=$1 AND account_id=$2" - " ORDER BY session DESC LIMIT 1", key, self.account_id) + row = await self.db.fetchrow("SELECT session_id, session, created_at, last_used " + "FROM crypto_olm_session " + "WHERE sender_key=$1 AND account_id=$2 " + "ORDER BY session_id DESC LIMIT 1", key, self.account_id) if row is None: return None - return Session.from_pickle(row["session"], passphrase=self.pickle_key, - creation_time=row["created_at"], use_time=row["last_used"]) + try: + return self._olm_cache[key][row["session_id"]] + except KeyError: + return Session.from_pickle(row["session"], passphrase=self.pickle_key, + creation_time=row["created_at"], use_time=row["last_used"]) async def add_session(self, key: IdentityKey, session: Session) -> None: pickle = session.pickle(self.pickle_key) + self._olm_cache[key][session.id] = session await self.db.execute("INSERT INTO crypto_olm_session (session_id, sender_key, session, " "created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6)", session.id, key, pickle, session.creation_time, session.use_time,