Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing transport and relateds. #271

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions bumble/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import logging
import asyncio
import itertools
Expand Down Expand Up @@ -58,8 +60,10 @@
HCI_Packet,
HCI_Role_Change_Event,
)
from typing import Optional, Union, Dict
from typing import Optional, Union, Dict, TYPE_CHECKING

if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource

# -----------------------------------------------------------------------------
# Logging
Expand Down Expand Up @@ -104,7 +108,7 @@ def __init__(
self,
name,
host_source=None,
host_sink=None,
host_sink: Optional[TransportSink] = None,
link=None,
public_address: Optional[Union[bytes, str, Address]] = None,
):
Expand Down
44 changes: 37 additions & 7 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@
import logging
from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
TYPE_CHECKING,
)

from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
Expand Down Expand Up @@ -152,6 +163,9 @@
from . import l2cap
from . import core

if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink


# -----------------------------------------------------------------------------
# Logging
Expand Down Expand Up @@ -942,7 +956,13 @@ def on_characteristic_subscription(
pass

@classmethod
def with_hci(cls, name, address, hci_source, hci_sink):
def with_hci(
cls,
name: str,
address: Address,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
'''
Create a Device instance with a Host configured to communicate with a controller
through an HCI source/sink
Expand All @@ -951,18 +971,25 @@ def with_hci(cls, name, address, hci_source, hci_sink):
return cls(name=name, address=address, host=host)

@classmethod
def from_config_file(cls, filename):
def from_config_file(cls, filename: str) -> Device:
config = DeviceConfiguration()
config.load_from_file(filename)
return cls(config=config)

@classmethod
def from_config_with_hci(cls, config, hci_source, hci_sink):
def from_config_with_hci(
cls,
config: DeviceConfiguration,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
host = Host(controller_source=hci_source, controller_sink=hci_sink)
return cls(config=config, host=host)

@classmethod
def from_config_file_with_hci(cls, filename, hci_source, hci_sink):
def from_config_file_with_hci(
cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink
) -> Device:
config = DeviceConfiguration()
config.load_from_file(filename)
return cls.from_config_with_hci(config, hci_source, hci_sink)
Expand Down Expand Up @@ -2238,9 +2265,11 @@ async def pair(self, connection):
def request_pairing(self, connection):
return self.smp_manager.request_pairing(connection)

async def get_long_term_key(self, connection_handle, rand, ediv):
async def get_long_term_key(
self, connection_handle: int, rand: bytes, ediv: int
) -> Optional[bytes]:
if (connection := self.lookup_connection(connection_handle)) is None:
return
return None

# Start by looking for the key in an SMP session
ltk = self.smp_manager.get_long_term_key(connection, rand, ediv)
Expand All @@ -2260,6 +2289,7 @@ async def get_long_term_key(self, connection_handle, rand, ediv):

if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
return keys.ltk_peripheral.value
return None

async def get_link_key(self, address: Address) -> Optional[bytes]:
if self.keystore is None:
Expand Down
26 changes: 20 additions & 6 deletions bumble/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import struct

from typing import Optional
from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable

from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
Expand Down Expand Up @@ -73,10 +73,14 @@
BT_LE_TRANSPORT,
ConnectionPHY,
ConnectionParameters,
InvalidStateError,
)
from .utils import AbortableEventEmitter
from .transport.common import TransportLostError

if TYPE_CHECKING:
from .transport.common import TransportSink, TransportSource


# -----------------------------------------------------------------------------
# Logging
Expand Down Expand Up @@ -116,10 +120,21 @@ def on_acl_pdu(self, pdu: bytes) -> None:

# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
def __init__(self, controller_source=None, controller_sink=None):
connections: Dict[int, Connection]
acl_packet_queue: collections.deque[HCI_AclDataPacket]
hci_sink: TransportSink
long_term_key_provider: Optional[
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
]
link_key_provider: Optional[Callable[[Address], Awaitable[Optional[bytes]]]]

def __init__(
self,
controller_source: Optional[TransportSource] = None,
controller_sink: Optional[TransportSink] = None,
) -> None:
super().__init__()

self.hci_sink = None
self.hci_metadata = None
self.ready = False # True when we can accept incoming packets
self.reset_done = False
Expand Down Expand Up @@ -299,7 +314,7 @@ async def reset(self, driver_factory=drivers.get_driver_for_host):
self.reset_done = True

@property
def controller(self):
def controller(self) -> TransportSink:
return self.hci_sink

@controller.setter
Expand All @@ -308,13 +323,12 @@ def controller(self, controller):
if controller:
controller.set_packet_sink(self)

def set_packet_sink(self, sink):
def set_packet_sink(self, sink: TransportSink) -> None:
self.hci_sink = sink

def send_hci_packet(self, packet: HCI_Packet) -> None:
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)

self.hci_sink.on_packet(bytes(packet))

async def send_command(self, command, check_result=False):
Expand Down
11 changes: 7 additions & 4 deletions bumble/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import os

from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..controller import Controller
from ..snoop import create_snooper

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -119,7 +118,8 @@ async def _open_transport(name: str) -> Transport:
if scheme == 'file':
from .file import open_file_transport

return await open_file_transport(spec[0] if spec else None)
assert spec is not None
return await open_file_transport(spec[0])

if scheme == 'vhci':
from .vhci import open_vhci_transport
Expand All @@ -134,12 +134,14 @@ async def _open_transport(name: str) -> Transport:
if scheme == 'usb':
from .usb import open_usb_transport

return await open_usb_transport(spec[0] if spec else None)
assert spec is not None
return await open_usb_transport(spec[0])

if scheme == 'pyusb':
from .pyusb import open_pyusb_transport

return await open_pyusb_transport(spec[0] if spec else None)
assert spec is not None
return await open_pyusb_transport(spec[0])

if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
Expand Down Expand Up @@ -168,6 +170,7 @@ async def open_transport_or_link(name: str) -> Transport:
"""
if name.startswith('link-relay:'):
from ..controller import Controller
from ..link import RemoteLink # lazy import

link = RemoteLink(name[11:])
Expand Down
7 changes: 4 additions & 3 deletions bumble/transport/android_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import grpc.aio

from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport

# pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
Expand All @@ -33,7 +33,7 @@


# -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec):
async def open_android_emulator_transport(spec: str | None) -> Transport:
'''
Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax:
Expand Down Expand Up @@ -66,7 +66,7 @@ async def write(self, packet):
# Parse the parameters
mode = 'host'
server_host = 'localhost'
server_port = 8554
server_port = '8554'
if spec is not None:
params = spec.split(',')
for param in params:
Expand All @@ -82,6 +82,7 @@ async def write(self, packet):
logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address)

service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
if mode == 'host':
# Connect as a host
service = EmulatedBluetoothServiceStub(channel)
Expand Down
4 changes: 3 additions & 1 deletion bumble/transport/android_netsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def cleanup():


# -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport(server_host, server_port):
async def open_android_netsim_controller_transport(
server_host: str | None, server_port: int
) -> Transport:
if not server_port:
raise ValueError('invalid port')
if server_host == '_' or not server_host:
Expand Down
Loading