From 00e660d41065a948c89f475c76cd8bc87878d57e Mon Sep 17 00:00:00 2001 From: William Escande Date: Mon, 9 Sep 2024 16:24:22 -0700 Subject: [PATCH] Implement Hap support (#532) * Implement Hap --- bumble/att.py | 11 + bumble/gatt.py | 5 + bumble/gatt_server.py | 2 +- bumble/profiles/hap.py | 666 +++++++++++++++++++++++++++++++++++++ examples/run_hap_server.py | 107 ++++++ tests/hap_test.py | 227 +++++++++++++ 6 files changed, 1017 insertions(+), 1 deletion(-) create mode 100644 bumble/profiles/hap.py create mode 100644 examples/run_hap_server.py create mode 100644 tests/hap_test.py diff --git a/bumble/att.py b/bumble/att.py index e7995ae9..86d7fc60 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -23,6 +23,8 @@ # Imports # ----------------------------------------------------------------------------- from __future__ import annotations +from bumble.utils import OpenIntEnum + import enum import functools import inspect @@ -211,6 +213,15 @@ class ErrorCode(utils.OpenIntEnum): # pylint: disable=invalid-name +class CommonErrorCode(OpenIntEnum): + '''See Supplement to the Bluetooth Code Specification 1.2 List of Error Codes.''' + + WRITE_REQUEST_REJECTED = 0xFC + CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR_IMPROPERLY_CONFIGURED = 0xFD + PROCEDURE_ALREADY_IN_PROGRESS = 0xFE + OUT_OF_RANGE = 0xFF + + # ----------------------------------------------------------------------------- # Exceptions # ----------------------------------------------------------------------------- diff --git a/bumble/gatt.py b/bumble/gatt.py index a14fcda3..3e679bbe 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -275,6 +275,11 @@ GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts') GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts') +# Hearing Access Service +GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features') +GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point') +GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC = UUID.from_16_bits(0x2BDC, 'Active Preset Index') + # ASHA Service GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid') GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties') diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index 302fb4fe..0ee673c0 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -915,7 +915,7 @@ async def on_att_write_request(self, connection, request): See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request ''' - # Check that the attribute exists + # Check that the attribute exists attribute = self.get_attribute(request.attribute_handle) if attribute is None: self.send_response( diff --git a/bumble/profiles/hap.py b/bumble/profiles/hap.py new file mode 100644 index 00000000..e61ac4f8 --- /dev/null +++ b/bumble/profiles/hap.py @@ -0,0 +1,666 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +import functools +from bumble import att, gatt, gatt_client +from bumble.att import CommonErrorCode +from bumble.core import InvalidArgumentError, InvalidStateError +from bumble.device import Device, Connection +from bumble.utils import AsyncRunner, OpenIntEnum +from bumble.hci import Address +from dataclasses import dataclass, field +import logging +from typing import Dict, List, Optional, Set, Union + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +class ErrorCode(OpenIntEnum): + '''See Hearing Access Service 2.4. Attribute Profile error codes.''' + + INVALID_OPCODE = 0x80 + WRITE_NAME_NOT_ALLOWED = 0x81 + PRESET_SYNCHRONIZATION_NOT_SUPPORTED = 0x82 + PRESET_OPERATION_NOT_POSSIBLE = 0x83 + INVALID_PARAMETERS_LENGTH = 0x84 + + +class HearingAidType(OpenIntEnum): + '''See Hearing Access Service 3.1. Hearing Aid Features.''' + + BINAURAL_HEARING_AID = 0b00 + MONAURAL_HEARING_AID = 0b01 + BANDED_HEARING_AID = 0b10 + + +class PresetSynchronizationSupport(OpenIntEnum): + '''See Hearing Access Service 3.1. Hearing Aid Features.''' + + PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0 + PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1 + + +class IndependentPresets(OpenIntEnum): + '''See Hearing Access Service 3.1. Hearing Aid Features.''' + + IDENTICAL_PRESET_RECORD = 0b0 + DIFFERENT_PRESET_RECORD = 0b1 + + +class DynamicPresets(OpenIntEnum): + '''See Hearing Access Service 3.1. Hearing Aid Features.''' + + PRESET_RECORDS_DOES_NOT_CHANGE = 0b0 + PRESET_RECORDS_MAY_CHANGE = 0b1 + + +class WritablePresetsSupport(OpenIntEnum): + '''See Hearing Access Service 3.1. Hearing Aid Features.''' + + WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0 + WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1 + + +class HearingAidPresetControlPointOpcode(OpenIntEnum): + '''See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation requirements.''' + + # fmt: off + READ_PRESETS_REQUEST = 0x01 + READ_PRESET_RESPONSE = 0x02 + PRESET_CHANGED = 0x03 + WRITE_PRESET_NAME = 0x04 + SET_ACTIVE_PRESET = 0x05 + SET_NEXT_PRESET = 0x06 + SET_PREVIOUS_PRESET = 0x07 + SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY = 0x08 + SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY = 0x09 + SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY = 0x0A + + +@dataclass +class HearingAidFeatures: + '''See Hearing Access Service 3.1. Hearing Aid Features.''' + + hearing_aid_type: HearingAidType + preset_synchronization_support: PresetSynchronizationSupport + independent_presets: IndependentPresets + dynamic_presets: DynamicPresets + writable_presets_support: WritablePresetsSupport + + def __bytes__(self) -> bytes: + return bytes( + [ + (self.hearing_aid_type << 0) + | (self.preset_synchronization_support << 2) + | (self.independent_presets << 3) + | (self.dynamic_presets << 4) + | (self.writable_presets_support << 5) + ] + ) + + +def HearingAidFeatures_from_bytes(data: int) -> HearingAidFeatures: + return HearingAidFeatures( + HearingAidType(data & 0b11), + PresetSynchronizationSupport(data >> 2 & 0b1), + IndependentPresets(data >> 3 & 0b1), + DynamicPresets(data >> 4 & 0b1), + WritablePresetsSupport(data >> 5 & 0b1), + ) + + +@dataclass +class PresetChangedOperation: + '''See Hearing Access Service 3.2.2.2. Preset Changed operation.''' + + class ChangeId(OpenIntEnum): + # fmt: off + GENERIC_UPDATE = 0x00 + PRESET_RECORD_DELETED = 0x01 + PRESET_RECORD_AVAILABLE = 0x02 + PRESET_RECORD_UNAVAILABLE = 0x03 + + @dataclass + class Generic: + prev_index: int + preset_record: PresetRecord + + def __bytes__(self) -> bytes: + return bytes([self.prev_index]) + bytes(self.preset_record) + + change_id: ChangeId + additional_parameters: Union[Generic, int] + + def to_bytes(self, is_last: bool) -> bytes: + if isinstance(self.additional_parameters, PresetChangedOperation.Generic): + additional_parameters_bytes = bytes(self.additional_parameters) + else: + additional_parameters_bytes = bytes([self.additional_parameters]) + + return ( + bytes( + [ + HearingAidPresetControlPointOpcode.PRESET_CHANGED, + self.change_id, + is_last, + ] + ) + + additional_parameters_bytes + ) + + +class PresetChangedOperationDeleted(PresetChangedOperation): + def __init__(self, index) -> None: + self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_DELETED + self.additional_parameters = index + + +class PresetChangedOperationAvailable(PresetChangedOperation): + def __init__(self, index) -> None: + self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_AVAILABLE + self.additional_parameters = index + + +class PresetChangedOperationUnavailable(PresetChangedOperation): + def __init__(self, index) -> None: + self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_UNAVAILABLE + self.additional_parameters = index + + +@dataclass +class PresetRecord: + '''See Hearing Access Service 2.8. Preset record.''' + + @dataclass + class Property: + class Writable(OpenIntEnum): + CANNOT_BE_WRITTEN = 0b0 + CAN_BE_WRITTEN = 0b1 + + class IsAvailable(OpenIntEnum): + IS_UNAVAILABLE = 0b0 + IS_AVAILABLE = 0b1 + + writable: Writable = Writable.CAN_BE_WRITTEN + is_available: IsAvailable = IsAvailable.IS_AVAILABLE + + def __bytes__(self) -> bytes: + return bytes([self.writable | (self.is_available << 1)]) + + index: int + name: str + properties: Property = field(default_factory=Property) + + def __bytes__(self) -> bytes: + return bytes([self.index]) + bytes(self.properties) + self.name.encode('utf-8') + + def is_available(self) -> bool: + return ( + self.properties.is_available + == PresetRecord.Property.IsAvailable.IS_AVAILABLE + ) + + +# ----------------------------------------------------------------------------- +# Server +# ----------------------------------------------------------------------------- +class HearingAccessService(gatt.TemplateService): + UUID = gatt.GATT_HEARING_ACCESS_SERVICE + + hearing_aid_features_characteristic: gatt.Characteristic + hearing_aid_preset_control_point: gatt.Characteristic + active_preset_index_characteristic: gatt.Characteristic + active_preset_index: int + active_preset_index_per_device: Dict[Address, int] + + device: Device + + server_features: HearingAidFeatures + preset_records: Dict[int, PresetRecord] # key is the preset index + read_presets_request_in_progress: bool + + preset_changed_operations_history_per_device: Dict[ + Address, List[PresetChangedOperation] + ] + + # Keep an updated list of connected client to send notification to + currently_connected_clients: Set[Connection] + + def __init__( + self, device: Device, features: HearingAidFeatures, presets: List[PresetRecord] + ) -> None: + self.active_preset_index_per_device = {} + self.read_presets_request_in_progress = False + self.preset_changed_operations_history_per_device = {} + self.currently_connected_clients = set() + + self.device = device + self.server_features = features + if len(presets) < 1: + raise InvalidArgumentError(f'Invalid presets: {presets}') + + self.preset_records = {} + for p in presets: + if len(p.name.encode()) < 1 or len(p.name.encode()) > 40: + raise InvalidArgumentError(f'Invalid name: {p.name}') + + self.preset_records[p.index] = p + + # associate the lowest index as the current active preset at startup + self.active_preset_index = sorted(self.preset_records.keys())[0] + + @device.on('connection') # type: ignore + def on_connection(connection: Connection) -> None: + @connection.on('disconnection') # type: ignore + def on_disconnection(_reason) -> None: + self.currently_connected_clients.remove(connection) + + # TODO Should we filter on device bonded && device is HAP ? + self.currently_connected_clients.add(connection) + if ( + connection.peer_address + not in self.preset_changed_operations_history_per_device + ): + self.preset_changed_operations_history_per_device[ + connection.peer_address + ] = [] + return + + async def on_connection_async() -> None: + # Send all the PresetChangedOperation that occur when not connected + await self._preset_changed_operation(connection) + # Update the active preset index if needed + await self.notify_active_preset_for_connection(connection) + + connection.abort_on('disconnection', on_connection_async()) + + self.hearing_aid_features_characteristic = gatt.Characteristic( + uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC, + properties=gatt.Characteristic.Properties.READ, + permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, + value=bytes(self.server_features), + ) + self.hearing_aid_preset_control_point = gatt.Characteristic( + uuid=gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC, + properties=( + gatt.Characteristic.Properties.WRITE + | gatt.Characteristic.Properties.INDICATE + ), + permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION, + value=gatt.CharacteristicValue( + write=self._on_write_hearing_aid_preset_control_point + ), + ) + self.active_preset_index_characteristic = gatt.Characteristic( + uuid=gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC, + properties=( + gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.NOTIFY + ), + permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, + value=gatt.CharacteristicValue(read=self._on_read_active_preset_index), + ) + + super().__init__( + [ + self.hearing_aid_features_characteristic, + self.hearing_aid_preset_control_point, + self.active_preset_index_characteristic, + ] + ) + + def _on_read_active_preset_index( + self, __connection__: Optional[Connection] + ) -> bytes: + return bytes([self.active_preset_index]) + + # TODO this need to be triggered when device is unbonded + def on_forget(self, addr: Address) -> None: + self.preset_changed_operations_history_per_device.pop(addr) + + async def _on_write_hearing_aid_preset_control_point( + self, connection: Optional[Connection], value: bytes + ): + assert connection + + opcode = HearingAidPresetControlPointOpcode(value[0]) + handler = getattr(self, '_on_' + opcode.name.lower()) + await handler(connection, value) + + async def _on_read_presets_request( + self, connection: Optional[Connection], value: bytes + ): + assert connection + if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements + logging.warning(f'HAS require MTU >= 49: {connection}') + + if self.read_presets_request_in_progress: + raise att.ATT_Error(CommonErrorCode.PROCEDURE_ALREADY_IN_PROGRESS) + self.read_presets_request_in_progress = True + + start_index = value[1] + if start_index == 0x00: + raise att.ATT_Error(CommonErrorCode.OUT_OF_RANGE) + + num_presets = value[2] + if num_presets == 0x00: + raise att.ATT_Error(CommonErrorCode.OUT_OF_RANGE) + + # Sending `num_presets` presets ordered by increasing index field, starting from start_index + presets = [ + self.preset_records[key] + for key in sorted(self.preset_records.keys()) + if self.preset_records[key].index >= start_index + ] + del presets[num_presets:] + if len(presets) == 0: + raise att.ATT_Error(CommonErrorCode.OUT_OF_RANGE) + + AsyncRunner.spawn(self._read_preset_response(connection, presets)) + + async def _read_preset_response( + self, connection: Connection, presets: List[PresetRecord] + ): + # If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects. + try: + for i, preset in enumerate(presets): + await connection.device.indicate_subscriber( + connection, + self.hearing_aid_preset_control_point, + value=bytes( + [ + HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, + i == len(presets) - 1, + ] + ) + + bytes(preset), + ) + + finally: + # indicate_subscriber can raise a TimeoutError, we need to gracefully terminate the operation + self.read_presets_request_in_progress = False + + async def generic_update(self, op: PresetChangedOperation) -> None: + '''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent''' + await self._notifyPresetOperations(op) + + async def delete_preset(self, index: int) -> None: + '''Server API to delete a preset. It should not be the current active preset''' + + if index == self.active_preset_index: + raise InvalidStateError('Cannot delete active preset') + + del self.preset_records[index] + await self._notifyPresetOperations(PresetChangedOperationDeleted(index)) + + async def available_preset(self, index: int) -> None: + '''Server API to make a preset available''' + + preset = self.preset_records[index] + preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE + await self._notifyPresetOperations(PresetChangedOperationAvailable(index)) + + async def unavailable_preset(self, index: int) -> None: + '''Server API to make a preset unavailable. It should not be the current active preset''' + + if index == self.active_preset_index: + raise InvalidStateError('Cannot set active preset as unavailable') + + preset = self.preset_records[index] + preset.properties.is_available = ( + PresetRecord.Property.IsAvailable.IS_UNAVAILABLE + ) + await self._notifyPresetOperations(PresetChangedOperationUnavailable(index)) + + async def _preset_changed_operation(self, connection: Connection) -> None: + '''Send all PresetChangedOperation saved for a given connection''' + op_list = self.preset_changed_operations_history_per_device.get( + connection.peer_address, [] + ) + + # Notification will be sent in index order + def get_op_index(op: PresetChangedOperation) -> int: + if isinstance(op.additional_parameters, PresetChangedOperation.Generic): + return op.additional_parameters.prev_index + return op.additional_parameters + + op_list.sort(key=get_op_index) + # If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects. + while len(op_list) > 0: + try: + await connection.device.indicate_subscriber( + connection, + self.hearing_aid_preset_control_point, + value=op_list[0].to_bytes(len(op_list) == 1), + ) + # Remove item once sent, and keep the non sent item in the list + op_list.pop(0) + except TimeoutError: + break + + async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None: + for historyList in self.preset_changed_operations_history_per_device.values(): + historyList.append(op) + + for connection in self.currently_connected_clients: + await self._preset_changed_operation(connection) + + async def _on_write_preset_name( + self, connection: Optional[Connection], value: bytes + ): + assert connection + + if self.read_presets_request_in_progress: + raise att.ATT_Error(CommonErrorCode.PROCEDURE_ALREADY_IN_PROGRESS) + + index = value[1] + preset = self.preset_records.get(index, None) + if ( + not preset + or preset.properties.writable + == PresetRecord.Property.Writable.CANNOT_BE_WRITTEN + ): + raise att.ATT_Error(ErrorCode.WRITE_NAME_NOT_ALLOWED) + + name = value[2:].decode('utf-8') + if not name or len(name) > 40: + raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH) + + preset.name = name + + await self.generic_update( + PresetChangedOperation( + PresetChangedOperation.ChangeId.GENERIC_UPDATE, + PresetChangedOperation.Generic(index, preset), + ) + ) + + async def notify_active_preset_for_connection(self, connection: Connection) -> None: + if ( + self.active_preset_index_per_device.get(connection.peer_address, 0x00) + == self.active_preset_index + ): + # Nothing to do, peer is already updated + return + + await connection.device.notify_subscriber( + connection, + attribute=self.active_preset_index_characteristic, + value=bytes([self.active_preset_index]), + ) + self.active_preset_index_per_device[connection.peer_address] = ( + self.active_preset_index + ) + + async def notify_active_preset(self) -> None: + for connection in self.currently_connected_clients: + await self.notify_active_preset_for_connection(connection) + + async def set_active_preset( + self, connection: Optional[Connection], value: bytes + ) -> None: + assert connection + index = value[1] + preset = self.preset_records.get(index, None) + if ( + not preset + or preset.properties.is_available + != PresetRecord.Property.IsAvailable.IS_AVAILABLE + ): + raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE) + + if index == self.active_preset_index: + # Already at correct value + return + + self.active_preset_index = index + await self.notify_active_preset() + + async def _on_set_active_preset( + self, connection: Optional[Connection], value: bytes + ): + await self.set_active_preset(connection, value) + + async def set_next_or_previous_preset( + self, connection: Optional[Connection], is_previous + ): + '''Set the next or the previous preset as active''' + assert connection + + if self.active_preset_index == 0x00: + raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE) + + first_preset: Optional[PresetRecord] = None # To loop to first preset + next_preset: Optional[PresetRecord] = None + for index, record in sorted(self.preset_records.items(), reverse=is_previous): + if not record.is_available(): + continue + if first_preset == None: + first_preset = record + if is_previous: + if index >= self.active_preset_index: + continue + elif index <= self.active_preset_index: + continue + next_preset = record + break + + if not first_preset: # If no other preset are available + raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE) + + if next_preset: + self.active_preset_index = next_preset.index + else: + self.active_preset_index = first_preset.index + await self.notify_active_preset() + + async def _on_set_next_preset( + self, connection: Optional[Connection], __value__: bytes + ) -> None: + await self.set_next_or_previous_preset(connection, False) + + async def _on_set_previous_preset( + self, connection: Optional[Connection], __value__: bytes + ) -> None: + await self.set_next_or_previous_preset(connection, True) + + async def _on_set_active_preset_synchronized_locally( + self, connection: Optional[Connection], value: bytes + ): + if ( + self.server_features.preset_synchronization_support + == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED + ): + raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED) + await self.set_active_preset(connection, value) + # TODO (low priority) inform other server of the change + + async def _on_set_next_preset_synchronized_locally( + self, connection: Optional[Connection], __value__: bytes + ): + if ( + self.server_features.preset_synchronization_support + == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED + ): + raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED) + await self.set_next_or_previous_preset(connection, False) + # TODO (low priority) inform other server of the change + + async def _on_set_previous_preset_synchronized_locally( + self, connection: Optional[Connection], __value__: bytes + ): + if ( + self.server_features.preset_synchronization_support + == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED + ): + raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED) + await self.set_next_or_previous_preset(connection, True) + # TODO (low priority) inform other server of the change + + +# ----------------------------------------------------------------------------- +# Client +# ----------------------------------------------------------------------------- +class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy): + SERVICE_CLASS = HearingAccessService + + hearing_aid_preset_control_point: gatt_client.CharacteristicProxy + preset_control_point_indications: asyncio.Queue + + def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: + self.service_proxy = service_proxy + + self.server_features = gatt.PackedCharacteristicAdapter( + service_proxy.get_characteristics_by_uuid( + gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC + )[0], + 'B', + ) + + self.hearing_aid_preset_control_point = ( + service_proxy.get_characteristics_by_uuid( + gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC + )[0] + ) + + self.active_preset_index = gatt.PackedCharacteristicAdapter( + service_proxy.get_characteristics_by_uuid( + gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC + )[0], + 'B', + ) + + async def setup_subscription(self): + self.preset_control_point_indications = asyncio.Queue() + self.active_preset_index_notification = asyncio.Queue() + + def on_active_preset_index_notification(data: bytes): + self.active_preset_index_notification.put_nowait(data) + + def on_preset_control_point_indication(data: bytes): + self.preset_control_point_indications.put_nowait(data) + + await self.hearing_aid_preset_control_point.subscribe( + functools.partial(on_preset_control_point_indication), prefer_notify=False + ) + + await self.active_preset_index.subscribe( + functools.partial(on_active_preset_index_notification) + ) diff --git a/examples/run_hap_server.py b/examples/run_hap_server.py new file mode 100644 index 00000000..18f1c387 --- /dev/null +++ b/examples/run_hap_server.py @@ -0,0 +1,107 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import logging +import sys +import os + +from bumble.core import AdvertisingData +from bumble.device import Device +from bumble import att +from bumble.profiles.hap import ( + HearingAccessService, + HearingAidFeatures, + HearingAidType, + PresetSynchronizationSupport, + IndependentPresets, + DynamicPresets, + WritablePresetsSupport, + PresetRecord, +) + +from bumble.transport import open_transport_or_link + +server_features = HearingAidFeatures( + HearingAidType.MONAURAL_HEARING_AID, + PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED, + IndependentPresets.IDENTICAL_PRESET_RECORD, + DynamicPresets.PRESET_RECORDS_DOES_NOT_CHANGE, + WritablePresetsSupport.WRITABLE_PRESET_RECORDS_SUPPORTED, +) + +foo_preset = PresetRecord(1, "foo preset") +bar_preset = PresetRecord(50, "bar preset") +foobar_preset = PresetRecord(5, "foobar preset") + + +# ----------------------------------------------------------------------------- +async def main() -> None: + if len(sys.argv) < 3: + print('Usage: run_hap_server.py ') + print('example: run_hap_server.py device1.json pty:hci_pty') + return + + print('<<< connecting to HCI...') + async with await open_transport_or_link(sys.argv[2]) as hci_transport: + print('<<< connected') + + device = Device.from_config_file_with_hci( + sys.argv[1], hci_transport.source, hci_transport.sink + ) + + await device.power_on() + + hap = HearingAccessService( + device, server_features, [foo_preset, bar_preset, foobar_preset] + ) + device.add_service(hap) + + advertising_data = bytes( + AdvertisingData( + [ + ( + AdvertisingData.COMPLETE_LOCAL_NAME, + bytes('Bumble HearingAccessService', 'utf-8'), + ), + ( + AdvertisingData.FLAGS, + bytes( + [ + AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG + | AdvertisingData.BR_EDR_HOST_FLAG + | AdvertisingData.BR_EDR_CONTROLLER_FLAG + ] + ), + ), + ( + AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, + bytes(HearingAccessService.UUID), + ), + ] + ) + ) + + await device.create_advertising_set( + advertising_data=advertising_data, + auto_restart=True, + ) + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +asyncio.run(main()) diff --git a/tests/hap_test.py b/tests/hap_test.py new file mode 100644 index 00000000..58392fd1 --- /dev/null +++ b/tests/hap_test.py @@ -0,0 +1,227 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import pytest +import functools +import pytest_asyncio +import logging +import sys + +from bumble import att, device +from bumble.profiles import hap +from .test_utils import TwoDevices + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +foo_preset = hap.PresetRecord(1, "foo preset") +bar_preset = hap.PresetRecord(50, "bar preset") +foobar_preset = hap.PresetRecord(5, "foobar preset") +unavailable_preset = hap.PresetRecord( + 78, + "foobar preset", + hap.PresetRecord.Property( + hap.PresetRecord.Property.Writable.CANNOT_BE_WRITTEN, + hap.PresetRecord.Property.IsAvailable.IS_UNAVAILABLE, + ), +) + +server_features = hap.HearingAidFeatures( + hap.HearingAidType.MONAURAL_HEARING_AID, + hap.PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED, + hap.IndependentPresets.IDENTICAL_PRESET_RECORD, + hap.DynamicPresets.PRESET_RECORDS_DOES_NOT_CHANGE, + hap.WritablePresetsSupport.WRITABLE_PRESET_RECORDS_SUPPORTED, +) + +TIMEOUT = 0.1 + + +async def assert_queue_is_empty(queue: asyncio.Queue): + assert queue.empty() + + # Check that nothing is being added during TIMEOUT secondes + if sys.version_info >= (3, 11): + with pytest.raises(TimeoutError): + await asyncio.wait_for(queue.get(), TIMEOUT) + else: + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(queue.get(), TIMEOUT) + + +# ----------------------------------------------------------------------------- +@pytest_asyncio.fixture +async def hap_client(): + devices = TwoDevices() + devices[0].add_service( + hap.HearingAccessService( + devices[0], + server_features, + [foo_preset, bar_preset, foobar_preset, unavailable_preset], + ) + ) + + await devices.setup_connection() + # TODO negotiate MTU > 49 to not truncate preset names + + # Mock encryption. + devices.connections[0].encryption = 1 # type: ignore + devices.connections[1].encryption = 1 # type: ignore + + peer = device.Peer(devices.connections[1]) # type: ignore + hap_client = await peer.discover_service_and_create_proxy( + hap.HearingAccessServiceProxy + ) + assert hap_client + await hap_client.setup_subscription() + + yield hap_client + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_init_service(hap_client: hap.HearingAccessServiceProxy): + assert ( + hap.HearingAidFeatures_from_bytes(await hap_client.server_features.read_value()) + == server_features + ) + assert (await hap_client.active_preset_index.read_value()) == (foo_preset.index) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_read_all_presets(hap_client: hap.HearingAccessServiceProxy): + await hap_client.hearing_aid_preset_control_point.write_value( + bytes([hap.HearingAidPresetControlPointOpcode.READ_PRESETS_REQUEST, 1, 0xFF]) + ) + assert (await hap_client.preset_control_point_indications.get()) == bytes( + [hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0] + ) + bytes(foo_preset) + assert (await hap_client.preset_control_point_indications.get()) == bytes( + [hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0] + ) + bytes(foobar_preset) + assert (await hap_client.preset_control_point_indications.get()) == bytes( + [hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0] + ) + bytes(bar_preset) + assert (await hap_client.preset_control_point_indications.get()) == bytes( + [hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 1] + ) + bytes(unavailable_preset) + + await assert_queue_is_empty(hap_client.preset_control_point_indications) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_read_partial_presets(hap_client: hap.HearingAccessServiceProxy): + await hap_client.hearing_aid_preset_control_point.write_value( + bytes([hap.HearingAidPresetControlPointOpcode.READ_PRESETS_REQUEST, 3, 2]) + ) + assert (await hap_client.preset_control_point_indications.get())[2:] == bytes( + foobar_preset + ) + assert (await hap_client.preset_control_point_indications.get())[2:] == bytes( + bar_preset + ) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_set_active_preset_valid(hap_client: hap.HearingAccessServiceProxy): + await hap_client.hearing_aid_preset_control_point.write_value( + bytes( + [hap.HearingAidPresetControlPointOpcode.SET_ACTIVE_PRESET, bar_preset.index] + ) + ) + assert (await hap_client.active_preset_index_notification.get()) == bar_preset.index + + assert (await hap_client.active_preset_index.read_value()) == (bar_preset.index) + + await assert_queue_is_empty(hap_client.active_preset_index_notification) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_set_active_preset_invalid(hap_client: hap.HearingAccessServiceProxy): + with pytest.raises(att.ATT_Error) as e: + await hap_client.hearing_aid_preset_control_point.write_value( + bytes( + [ + hap.HearingAidPresetControlPointOpcode.SET_ACTIVE_PRESET, + unavailable_preset.index, + ] + ), + with_response=True, + ) + assert e.value.error_code == hap.ErrorCode.PRESET_OPERATION_NOT_POSSIBLE + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_set_next_preset(hap_client: hap.HearingAccessServiceProxy): + await hap_client.hearing_aid_preset_control_point.write_value( + bytes([hap.HearingAidPresetControlPointOpcode.SET_NEXT_PRESET]) + ) + assert ( + await hap_client.active_preset_index_notification.get() + ) == foobar_preset.index + + assert (await hap_client.active_preset_index.read_value()) == (foobar_preset.index) + + await assert_queue_is_empty(hap_client.active_preset_index_notification) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_set_next_preset_will_loop_to_first( + hap_client: hap.HearingAccessServiceProxy, +): + async def go_next(new_preset: hap.PresetRecord): + await hap_client.hearing_aid_preset_control_point.write_value( + bytes([hap.HearingAidPresetControlPointOpcode.SET_NEXT_PRESET]) + ) + assert ( + await hap_client.active_preset_index_notification.get() + ) == new_preset.index + + assert (await hap_client.active_preset_index.read_value()) == (new_preset.index) + + await go_next(foobar_preset) + await go_next(bar_preset) + await go_next(foo_preset) + + # Note that there is a invalid preset in the preset record of the server + + await assert_queue_is_empty(hap_client.active_preset_index_notification) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_set_previous_preset_will_loop_to_last( + hap_client: hap.HearingAccessServiceProxy, +): + await hap_client.hearing_aid_preset_control_point.write_value( + bytes([hap.HearingAidPresetControlPointOpcode.SET_PREVIOUS_PRESET]) + ) + assert (await hap_client.active_preset_index_notification.get()) == bar_preset.index + + assert (await hap_client.active_preset_index.read_value()) == (bar_preset.index) + + await assert_queue_is_empty(hap_client.active_preset_index_notification)