diff --git a/apps/bench.py b/apps/bench.py index 67299dbc..7d934ced 100644 --- a/apps/bench.py +++ b/apps/bench.py @@ -16,6 +16,7 @@ # Imports # ----------------------------------------------------------------------------- import asyncio +import dataclasses import enum import logging import os @@ -97,34 +98,6 @@ # ----------------------------------------------------------------------------- # Utils # ----------------------------------------------------------------------------- -def parse_packet(packet): - if len(packet) < 1: - logging.info( - color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red') - ) - raise ValueError('packet too short') - - try: - packet_type = PacketType(packet[0]) - except ValueError: - logging.info(color(f'!!! Invalid packet type 0x{packet[0]:02X}', 'red')) - raise - - return (packet_type, packet[1:]) - - -def parse_packet_sequence(packet_data): - if len(packet_data) < 5: - logging.info( - color( - f'!!!Packet too short (got {len(packet_data)} bytes, need >= 5)', - 'red', - ) - ) - raise ValueError('packet too short') - return struct.unpack_from('>bI', packet_data, 0) - - def le_phy_name(phy_id): return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get( phy_id, HCI_Constant.le_phy_name(phy_id) @@ -225,13 +198,135 @@ async def switch_roles(connection, role): logging.info(f'{color("### Role switch failed:", "red")} {error}') -class PacketType(enum.IntEnum): - RESET = 0 - SEQUENCE = 1 - ACK = 2 +# ----------------------------------------------------------------------------- +# Packet +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class Packet: + class PacketType(enum.IntEnum): + RESET = 0 + SEQUENCE = 1 + ACK = 2 + + class PacketFlags(enum.IntFlag): + LAST = 1 + + packet_type: PacketType + flags: PacketFlags = PacketFlags(0) + sequence: int = 0 + timestamp: int = 0 + payload: bytes = b"" + + @classmethod + def from_bytes(cls, data: bytes): + if len(data) < 1: + logging.warning( + color(f'!!! Packet too short (got {len(data)} bytes, need >= 1)', 'red') + ) + raise ValueError('packet too short') + + try: + packet_type = cls.PacketType(data[0]) + except ValueError: + logging.warning(color(f'!!! Invalid packet type 0x{data[0]:02X}', 'red')) + raise + + if packet_type == cls.PacketType.RESET: + return cls(packet_type) + + flags = cls.PacketFlags(data[1]) + (sequence,) = struct.unpack_from("= 6)', + 'red', + ) + ) + return cls(packet_type, flags, sequence) + + if len(data) < 10: + logging.warning( + color( + f'!!! Packet too short (got {len(data)} bytes, need >= 10)', 'red' + ) + ) + raise ValueError('packet too short') + + (timestamp,) = struct.unpack_from(" 1: + expected_time = ( + self.receive_times[0] + + (packet.timestamp - self.packets[0].timestamp) / 1000000 + ) + jitter = now - expected_time + else: + jitter = 0.0 + self.jitter.append(jitter) + return jitter -PACKET_FLAG_LAST = 1 + def show_stats(self): + if len(self.jitter) < 3: + return + average = sum(self.jitter) / len(self.jitter) + adjusted = [jitter - average for jitter in self.jitter] + + log_stats('Jitter (signed)', adjusted, 3) + log_stats('Jitter (absolute)', [abs(jitter) for jitter in adjusted], 3) + + # Show a histogram + bin_count = 20 + bins = [0] * bin_count + interval_min = min(adjusted) + interval_max = max(adjusted) + interval_range = interval_max - interval_min + bin_thresholds = [ + interval_min + i * (interval_range / bin_count) for i in range(bin_count) + ] + for jitter in adjusted: + for i in reversed(range(bin_count)): + if jitter >= bin_thresholds[i]: + bins[i] += 1 + break + for i in range(bin_count): + logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}') # ----------------------------------------------------------------------------- @@ -281,19 +376,37 @@ async def run(self): await asyncio.sleep(self.tx_start_delay) logging.info(color('=== Sending RESET', 'magenta')) - await self.packet_io.send_packet(bytes([PacketType.RESET])) + await self.packet_io.send_packet( + bytes(Packet(packet_type=Packet.PacketType.RESET)) + ) + self.start_time = time.time() self.bytes_sent = 0 for tx_i in range(self.tx_packet_count): - packet_flags = ( - PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0 + if self.pace > 0: + # Wait until it is time to send the next packet + target_time = self.start_time + (tx_i * self.pace / 1000) + now = time.time() + if now < target_time: + await asyncio.sleep(target_time - now) + else: + await self.packet_io.drain() + + packet = bytes( + Packet( + packet_type=Packet.PacketType.SEQUENCE, + flags=( + Packet.PacketFlags.LAST + if tx_i == self.tx_packet_count - 1 + else 0 + ), + sequence=tx_i, + timestamp=int((time.time() - self.start_time) * 1000000), + payload=bytes( + self.tx_packet_size - 10 - self.packet_io.overhead_size + ), + ) ) - packet = struct.pack( - '>bbI', - PacketType.SEQUENCE, - packet_flags, - tx_i, - ) + bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size) logging.info( color( f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow' @@ -302,14 +415,6 @@ async def run(self): self.bytes_sent += len(packet) await self.packet_io.send_packet(packet) - if self.pace is None: - continue - - if self.pace > 0: - await asyncio.sleep(self.pace / 1000) - else: - await self.packet_io.drain() - await self.done.wait() run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else '' @@ -321,13 +426,13 @@ async def run(self): if self.repeat: logging.info(color('--- End of runs', 'blue')) - def on_packet_received(self, packet): + def on_packet_received(self, data): try: - packet_type, _ = parse_packet(packet) + packet = Packet.from_bytes(data) except ValueError: return - if packet_type == PacketType.ACK: + if packet.packet_type == Packet.PacketType.ACK: elapsed = time.time() - self.start_time average_tx_speed = self.bytes_sent / elapsed self.stats.append(average_tx_speed) @@ -350,52 +455,53 @@ class Receiver: last_timestamp: float def __init__(self, packet_io, linger): - self.reset() + self.jitter_stats = JitterStats() self.packet_io = packet_io self.packet_io.packet_listener = self self.linger = linger self.done = asyncio.Event() + self.reset() def reset(self): self.expected_packet_index = 0 self.measurements = [(time.time(), 0)] self.total_bytes_received = 0 + self.jitter_stats.reset() - def on_packet_received(self, packet): + def on_packet_received(self, data): try: - packet_type, packet_data = parse_packet(packet) + packet = Packet.from_bytes(data) except ValueError: + logging.exception("invalid packet") return - if packet_type == PacketType.RESET: + if packet.packet_type == Packet.PacketType.RESET: logging.info(color('=== Received RESET', 'magenta')) self.reset() return - try: - packet_flags, packet_index = parse_packet_sequence(packet_data) - except ValueError: - return + jitter = self.jitter_stats.on_packet_received(packet) logging.info( - f'<<< Received packet {packet_index}: ' - f'flags=0x{packet_flags:02X}, ' - f'{len(packet) + self.packet_io.overhead_size} bytes' + f'<<< Received packet {packet.sequence}: ' + f'flags={packet.flags}, ' + f'jitter={jitter:.4f}, ' + f'{len(data) + self.packet_io.overhead_size} bytes', ) - if packet_index != self.expected_packet_index: + if packet.sequence != self.expected_packet_index: logging.info( color( f'!!! Unexpected packet, expected {self.expected_packet_index} ' - f'but received {packet_index}' + f'but received {packet.sequence}' ) ) now = time.time() elapsed_since_start = now - self.measurements[0][0] elapsed_since_last = now - self.measurements[-1][0] - self.measurements.append((now, len(packet))) - self.total_bytes_received += len(packet) - instant_rx_speed = len(packet) / elapsed_since_last + self.measurements.append((now, len(data))) + self.total_bytes_received += len(data) + instant_rx_speed = len(data) / elapsed_since_last average_rx_speed = self.total_bytes_received / elapsed_since_start window = self.measurements[-64:] windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / ( @@ -411,15 +517,17 @@ def on_packet_received(self, packet): ) ) - self.expected_packet_index = packet_index + 1 + self.expected_packet_index = packet.sequence + 1 - if packet_flags & PACKET_FLAG_LAST: + if packet.flags & Packet.PacketFlags.LAST: AsyncRunner.spawn( self.packet_io.send_packet( - struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index) + bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence)) ) ) logging.info(color('@@@ Received last packet', 'green')) + self.jitter_stats.show_stats() + if not self.linger: self.done.set() @@ -479,25 +587,32 @@ async def run(self): await asyncio.sleep(self.tx_start_delay) logging.info(color('=== Sending RESET', 'magenta')) - await self.packet_io.send_packet(bytes([PacketType.RESET])) + await self.packet_io.send_packet(bytes(Packet(Packet.PacketType.RESET))) - packet_interval = self.pace / 1000 start_time = time.time() self.next_expected_packet_index = 0 for i in range(self.tx_packet_count): - target_time = start_time + (i * packet_interval) + target_time = start_time + (i * self.pace / 1000) now = time.time() if now < target_time: await asyncio.sleep(target_time - now) - - packet = struct.pack( - '>bbI', - PacketType.SEQUENCE, - (PACKET_FLAG_LAST if i == self.tx_packet_count - 1 else 0), - i, - ) + bytes(self.tx_packet_size - 6) + now = time.time() + + packet = bytes( + Packet( + packet_type=Packet.PacketType.SEQUENCE, + flags=( + Packet.PacketFlags.LAST + if i == self.tx_packet_count - 1 + else 0 + ), + sequence=i, + timestamp=int((now - start_time) * 1000000), + payload=bytes(self.tx_packet_size - 10), + ) + ) logging.info(color(f'Sending packet {i}', 'yellow')) - self.ping_times.append(time.time()) + self.ping_times.append(now) await self.packet_io.send_packet(packet) await self.done.wait() @@ -531,40 +646,35 @@ async def run(self): if self.repeat: logging.info(color('--- End of runs', 'blue')) - def on_packet_received(self, packet): + def on_packet_received(self, data): try: - packet_type, packet_data = parse_packet(packet) + packet = Packet.from_bytes(data) except ValueError: return - try: - packet_flags, packet_index = parse_packet_sequence(packet_data) - except ValueError: - return - - if packet_type == PacketType.ACK: - elapsed = time.time() - self.ping_times[packet_index] + if packet.packet_type == Packet.PacketType.ACK: + elapsed = time.time() - self.ping_times[packet.sequence] rtt = elapsed * 1000 self.rtts.append(rtt) logging.info( color( - f'<<< Received ACK [{packet_index}], RTT={rtt:.2f}ms', + f'<<< Received ACK [{packet.sequence}], RTT={rtt:.2f}ms', 'green', ) ) - if packet_index == self.next_expected_packet_index: + if packet.sequence == self.next_expected_packet_index: self.next_expected_packet_index += 1 else: logging.info( color( f'!!! Unexpected packet, ' f'expected {self.next_expected_packet_index} ' - f'but received {packet_index}' + f'but received {packet.sequence}' ) ) - if packet_flags & PACKET_FLAG_LAST: + if packet.flags & Packet.PacketFlags.LAST: self.done.set() return @@ -576,89 +686,56 @@ class Pong: expected_packet_index: int def __init__(self, packet_io, linger): - self.reset() + self.jitter_stats = JitterStats() self.packet_io = packet_io self.packet_io.packet_listener = self self.linger = linger self.done = asyncio.Event() + self.reset() def reset(self): self.expected_packet_index = 0 - self.receive_times = [] - - def on_packet_received(self, packet): - self.receive_times.append(time.time()) + self.jitter_stats.reset() + def on_packet_received(self, data): try: - packet_type, packet_data = parse_packet(packet) + packet = Packet.from_bytes(data) except ValueError: return - if packet_type == PacketType.RESET: + if packet.packet_type == Packet.PacketType.RESET: logging.info(color('=== Received RESET', 'magenta')) self.reset() return - try: - packet_flags, packet_index = parse_packet_sequence(packet_data) - except ValueError: - return - interval = ( - self.receive_times[-1] - self.receive_times[-2] - if len(self.receive_times) >= 2 - else 0 - ) + jitter = self.jitter_stats.on_packet_received(packet) logging.info( color( - f'<<< Received packet {packet_index}: ' - f'flags=0x{packet_flags:02X}, {len(packet)} bytes, ' - f'interval={interval:.4f}', + f'<<< Received packet {packet.sequence}: ' + f'flags={packet.flags}, {len(data)} bytes, ' + f'jitter={jitter:.4f}', 'green', ) ) - if packet_index != self.expected_packet_index: + if packet.sequence != self.expected_packet_index: logging.info( color( f'!!! Unexpected packet, expected {self.expected_packet_index} ' - f'but received {packet_index}' + f'but received {packet.sequence}' ) ) - self.expected_packet_index = packet_index + 1 + self.expected_packet_index = packet.sequence + 1 AsyncRunner.spawn( self.packet_io.send_packet( - struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index) + bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence)) ) ) - if packet_flags & PACKET_FLAG_LAST: - if len(self.receive_times) >= 3: - # Show basic stats - intervals = [ - self.receive_times[i + 1] - self.receive_times[i] - for i in range(len(self.receive_times) - 1) - ] - log_stats('Packet intervals', intervals, 3) - - # Show a histogram - bin_count = 20 - bins = [0] * bin_count - interval_min = min(intervals) - interval_max = max(intervals) - interval_range = interval_max - interval_min - bin_thresholds = [ - interval_min + i * (interval_range / bin_count) - for i in range(bin_count) - ] - for interval in intervals: - for i in reversed(range(bin_count)): - if interval >= bin_thresholds[i]: - bins[i] += 1 - break - for i in range(bin_count): - logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}') + if packet.flags & Packet.PacketFlags.LAST: + self.jitter_stats.show_stats() if not self.linger: self.done.set() @@ -1471,7 +1548,7 @@ def create_mode(device): def create_scenario_factory(ctx, default_scenario): scenario = ctx.obj['scenario'] if scenario is None: - scenarion = default_scenario + scenario = default_scenario def create_scenario(packet_io): if scenario == 'send': @@ -1605,7 +1682,7 @@ def create_scenario(packet_io): '--packet-size', '-s', metavar='SIZE', - type=click.IntRange(8, 8192), + type=click.IntRange(10, 8192), default=500, help='Packet size (send or ping scenario)', ) diff --git a/extras/android/BtBench/app/build.gradle.kts b/extras/android/BtBench/app/build.gradle.kts index 05d36e1a..887be166 100644 --- a/extras/android/BtBench/app/build.gradle.kts +++ b/extras/android/BtBench/app/build.gradle.kts @@ -10,7 +10,7 @@ android { defaultConfig { applicationId = "com.github.google.bumble.btbench" - minSdk = 30 + minSdk = 33 targetSdk = 34 versionCode = 1 versionName = "1.0" diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Connection.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Connection.kt new file mode 100644 index 00000000..7f27c831 --- /dev/null +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Connection.kt @@ -0,0 +1,109 @@ +package com.github.google.bumble.btbench + +import android.annotation.SuppressLint +import android.bluetooth.BluetoothAdapter +import android.bluetooth.BluetoothDevice +import android.bluetooth.BluetoothGatt +import android.bluetooth.BluetoothGattCallback +import android.bluetooth.BluetoothManager +import android.bluetooth.BluetoothProfile +import android.content.Context +import android.os.Build +import androidx.core.content.ContextCompat +import java.util.logging.Logger + +private val Log = Logger.getLogger("btbench.connection") + +open class Connection( + private val viewModel: AppViewModel, + private val bluetoothAdapter: BluetoothAdapter, + private val context: Context +) : BluetoothGattCallback() { + var remoteDevice: BluetoothDevice? = null + var gatt: BluetoothGatt? = null + + @SuppressLint("MissingPermission") + open fun connect() { + val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P") + val address = viewModel.peerBluetoothAddress.take(17) + remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + bluetoothAdapter.getRemoteLeDevice( + address, + if (addressIsPublic) { + BluetoothDevice.ADDRESS_TYPE_PUBLIC + } else { + BluetoothDevice.ADDRESS_TYPE_RANDOM + } + ) + } else { + bluetoothAdapter.getRemoteDevice(address) + } + + gatt = remoteDevice?.connectGatt( + context, + false, + this, + BluetoothDevice.TRANSPORT_LE, + if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK + ) + } + + @SuppressLint("MissingPermission") + fun disconnect() { + gatt?.disconnect() + } + + override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) { + Log.info("MTU update: mtu=$mtu status=$status") + viewModel.mtu = mtu + } + + override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) { + Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status") + viewModel.txPhy = txPhy + viewModel.rxPhy = rxPhy + } + + override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) { + Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status") + viewModel.txPhy = txPhy + viewModel.rxPhy = rxPhy + } + + @SuppressLint("MissingPermission") + override fun onConnectionStateChange( + gatt: BluetoothGatt?, status: Int, newState: Int + ) { + if (status != BluetoothGatt.GATT_SUCCESS) { + Log.warning("onConnectionStateChange status=$status") + } + + if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) { + if (viewModel.use2mPhy) { + Log.info("requesting 2M PHY") + gatt.setPreferredPhy( + BluetoothDevice.PHY_LE_2M_MASK, + BluetoothDevice.PHY_LE_2M_MASK, + BluetoothDevice.PHY_OPTION_NO_PREFERRED + ) + } + gatt.readPhy() + + // Request an MTU update, even though we don't use GATT, because Android + // won't request a larger link layer maximum data length otherwise. + gatt.requestMtu(517) + + // Request a specific connection priority + val connectionPriority = when (viewModel.connectionPriority) { + "BALANCED" -> BluetoothGatt.CONNECTION_PRIORITY_BALANCED + "LOW_POWER" -> BluetoothGatt.CONNECTION_PRIORITY_LOW_POWER + "HIGH" -> BluetoothGatt.CONNECTION_PRIORITY_HIGH + "DCK" -> BluetoothGatt.CONNECTION_PRIORITY_DCK + else -> 0 + } + if (!gatt.requestConnectionPriority(connectionPriority)) { + Log.warning("requestConnectionPriority failed") + } + } + } +} \ No newline at end of file diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/GattClient.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/GattClient.kt new file mode 100644 index 00000000..5d60f6b5 --- /dev/null +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/GattClient.kt @@ -0,0 +1,219 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.github.google.bumble.btbench + +import android.annotation.SuppressLint +import android.bluetooth.BluetoothAdapter +import android.bluetooth.BluetoothGatt +import android.bluetooth.BluetoothGattCharacteristic +import android.bluetooth.BluetoothGattDescriptor +import android.bluetooth.BluetoothProfile +import android.content.Context +import java.io.IOException +import java.util.UUID +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Semaphore +import java.util.logging.Logger +import kotlin.concurrent.thread + +private val Log = Logger.getLogger("btbench.gatt-client") + +private var CCCD_UUID = UUID.fromString("00002902-0000-1000-8000-00805F9B34FB") + +private val SPEED_SERVICE_UUID = UUID.fromString("50DB505C-8AC4-4738-8448-3B1D9CC09CC5") +private val SPEED_TX_UUID = UUID.fromString("E789C754-41A1-45F4-A948-A0A1A90DBA53") +private val SPEED_RX_UUID = UUID.fromString("016A2CC7-E14B-4819-935F-1F56EAE4098D") + + +class GattClientConnection( + viewModel: AppViewModel, + bluetoothAdapter: BluetoothAdapter, + context: Context +) : Connection(viewModel, bluetoothAdapter, context), PacketIO { + override var packetSink: PacketSink? = null + private val discoveryDone: CountDownLatch = CountDownLatch(1) + private val writeSemaphore: Semaphore = Semaphore(1) + var rxCharacteristic: BluetoothGattCharacteristic? = null + var txCharacteristic: BluetoothGattCharacteristic? = null + + override fun connect() { + super.connect() + + // Check if we're already connected and have discovered the services + if (gatt?.getService(SPEED_SERVICE_UUID) != null) { + onServicesDiscovered(gatt, BluetoothGatt.GATT_SUCCESS) + } + } + + @SuppressLint("MissingPermission") + override fun onConnectionStateChange( + gatt: BluetoothGatt?, status: Int, newState: Int + ) { + super.onConnectionStateChange(gatt, status, newState) + if (status != BluetoothGatt.GATT_SUCCESS) { + discoveryDone.countDown() + return + } + if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) { + if (!gatt.discoverServices()) { + Log.warning("discoverServices could not start") + discoveryDone.countDown() + } + } + } + + @SuppressLint("MissingPermission") + override fun onServicesDiscovered(gatt: BluetoothGatt?, status: Int) { + if (status != BluetoothGatt.GATT_SUCCESS) { + Log.warning("failed to discover services: ${status}") + discoveryDone.countDown() + return + } + + // Find the service + val service = gatt!!.getService(SPEED_SERVICE_UUID) + if (service == null) { + Log.warning("GATT Service not found") + discoveryDone.countDown() + return + } + + // Find the RX and TX characteristics + rxCharacteristic = service.getCharacteristic(SPEED_RX_UUID) + if (rxCharacteristic == null) { + Log.warning("GATT RX Characteristics not found") + discoveryDone.countDown() + return + } + txCharacteristic = service.getCharacteristic(SPEED_TX_UUID) + if (txCharacteristic == null) { + Log.warning("GATT TX Characteristics not found") + discoveryDone.countDown() + return + } + + // Subscribe to the RX characteristic + gatt.setCharacteristicNotification(rxCharacteristic, true) + val cccdDescriptor = rxCharacteristic!!.getDescriptor(CCCD_UUID) + gatt.writeDescriptor(cccdDescriptor, BluetoothGattDescriptor.ENABLE_NOTIFICATION_VALUE); + + Log.info("GATT discovery complete") + discoveryDone.countDown() + } + + override fun onCharacteristicWrite( + gatt: BluetoothGatt?, + characteristic: BluetoothGattCharacteristic?, + status: Int + ) { + // Now we can write again + writeSemaphore.release() + + if (status != BluetoothGatt.GATT_SUCCESS) { + Log.warning("onCharacteristicWrite failed: $status") + return + } + } + + override fun onCharacteristicChanged( + gatt: BluetoothGatt, + characteristic: BluetoothGattCharacteristic, + value: ByteArray + ) { + if (characteristic.uuid == SPEED_RX_UUID && packetSink != null) { + val packet = Packet.from(value) + packetSink!!.onPacket(packet) + } + } + + @SuppressLint("MissingPermission") + override fun sendPacket(packet: Packet) { + if (txCharacteristic == null) { + Log.warning("No TX characteristic, dropping") + return + } + + // Wait until we can write + writeSemaphore.acquire() + + // Write the data + val data = packet.toBytes() + val clampedData = if (data.size > 512) { + // Clamp the data to the maximum allowed characteristic data size + data.copyOf(512) + } else { + data + } + gatt?.writeCharacteristic( + txCharacteristic!!, + clampedData, + BluetoothGattCharacteristic.WRITE_TYPE_NO_RESPONSE + ) + } + + fun waitForDiscoveryCompletion() { + discoveryDone.await() + } +} + +class GattClient( + private val viewModel: AppViewModel, + bluetoothAdapter: BluetoothAdapter, + context: Context, + private val createIoClient: (packetIo: PacketIO) -> IoClient +) : Mode { + private var connection: GattClientConnection = + GattClientConnection(viewModel, bluetoothAdapter, context) + private var clientThread: Thread? = null + + @SuppressLint("MissingPermission") + override fun run() { + viewModel.running = true + + clientThread = thread(name = "GattClient") { + connection.connect() + + viewModel.aborter = { + connection.disconnect() + } + + // Discover the rx and tx characteristics + connection.waitForDiscoveryCompletion() + if (connection.rxCharacteristic == null || connection.txCharacteristic == null) { + connection.disconnect() + viewModel.running = false + return@thread + } + + val ioClient = createIoClient(connection) + + try { + ioClient.run() + viewModel.status = "OK" + } catch (error: IOException) { + Log.info("run ended abruptly") + viewModel.status = "ABORTED" + viewModel.lastError = "IO_ERROR" + } finally { + connection.disconnect() + viewModel.running = false + } + } + } + + override fun waitForCompletion() { + clientThread?.join() + } +} diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt index 5a4cc3c7..10494117 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt @@ -16,101 +16,25 @@ package com.github.google.bumble.btbench import android.annotation.SuppressLint import android.bluetooth.BluetoothAdapter -import android.bluetooth.BluetoothDevice -import android.bluetooth.BluetoothGatt -import android.bluetooth.BluetoothGattCallback -import android.bluetooth.BluetoothProfile import android.content.Context -import android.os.Build import java.util.logging.Logger private val Log = Logger.getLogger("btbench.l2cap-client") class L2capClient( private val viewModel: AppViewModel, - private val bluetoothAdapter: BluetoothAdapter, - private val context: Context, + bluetoothAdapter: BluetoothAdapter, + context: Context, private val createIoClient: (packetIo: PacketIO) -> IoClient ) : Mode { + private var connection: Connection = Connection(viewModel, bluetoothAdapter, context) private var socketClient: SocketClient? = null @SuppressLint("MissingPermission") override fun run() { viewModel.running = true - val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P") - val address = viewModel.peerBluetoothAddress.take(17) - val remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { - bluetoothAdapter.getRemoteLeDevice( - address, - if (addressIsPublic) { - BluetoothDevice.ADDRESS_TYPE_PUBLIC - } else { - BluetoothDevice.ADDRESS_TYPE_RANDOM - } - ) - } else { - bluetoothAdapter.getRemoteDevice(address) - } - - val gatt = remoteDevice.connectGatt( - context, - false, - object : BluetoothGattCallback() { - override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) { - Log.info("MTU update: mtu=$mtu status=$status") - viewModel.mtu = mtu - } - - override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) { - Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status") - viewModel.txPhy = txPhy - viewModel.rxPhy = rxPhy - } - - override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) { - Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status") - viewModel.txPhy = txPhy - viewModel.rxPhy = rxPhy - } - - override fun onConnectionStateChange( - gatt: BluetoothGatt?, status: Int, newState: Int - ) { - if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) { - if (viewModel.use2mPhy) { - Log.info("requesting 2M PHY") - gatt.setPreferredPhy( - BluetoothDevice.PHY_LE_2M_MASK, - BluetoothDevice.PHY_LE_2M_MASK, - BluetoothDevice.PHY_OPTION_NO_PREFERRED - ) - } - gatt.readPhy() - - // Request an MTU update, even though we don't use GATT, because Android - // won't request a larger link layer maximum data length otherwise. - gatt.requestMtu(517) - - // Request a specific connection priority - val connectionPriority = when (viewModel.connectionPriority) { - "BALANCED" -> BluetoothGatt.CONNECTION_PRIORITY_BALANCED - "LOW_POWER" -> BluetoothGatt.CONNECTION_PRIORITY_LOW_POWER - "HIGH" -> BluetoothGatt.CONNECTION_PRIORITY_HIGH - "DCK" -> BluetoothGatt.CONNECTION_PRIORITY_DCK - else -> 0 - } - if (!gatt.requestConnectionPriority(connectionPriority)) { - Log.warning("requestConnectionPriority failed") - } - } - } - }, - BluetoothDevice.TRANSPORT_LE, - if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK - ) - - val socket = remoteDevice.createInsecureL2capChannel(viewModel.l2capPsm) - + connection.connect() + val socket = connection.remoteDevice!!.createInsecureL2capChannel(viewModel.l2capPsm) socketClient = SocketClient(viewModel, socket, createIoClient) socketClient!!.run() } diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt index df5c53cb..131a26fc 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt @@ -146,9 +146,7 @@ class MainActivity : ComponentActivity() { initBluetooth() setContent { MainView( - appViewModel, - ::becomeDiscoverable, - ::runScenario + appViewModel, ::becomeDiscoverable, ::runScenario ) } @@ -184,6 +182,8 @@ class MainActivity : ComponentActivity() { "rfcomm-server" -> appViewModel.mode = RFCOMM_SERVER_MODE "l2cap-client" -> appViewModel.mode = L2CAP_CLIENT_MODE "l2cap-server" -> appViewModel.mode = L2CAP_SERVER_MODE + "gatt-client" -> appViewModel.mode = GATT_CLIENT_MODE + "gatt-server" -> appViewModel.mode = GATT_SERVER_MODE } } intent.getStringExtra("autostart")?.let { @@ -204,12 +204,14 @@ class MainActivity : ComponentActivity() { RFCOMM_CLIENT_MODE -> RfcommClient(appViewModel, bluetoothAdapter!!, ::createIoClient) RFCOMM_SERVER_MODE -> RfcommServer(appViewModel, bluetoothAdapter!!, ::createIoClient) L2CAP_CLIENT_MODE -> L2capClient( - appViewModel, - bluetoothAdapter!!, - baseContext, - ::createIoClient + appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient ) + L2CAP_SERVER_MODE -> L2capServer(appViewModel, bluetoothAdapter!!, ::createIoClient) + GATT_CLIENT_MODE -> GattClient( + appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient + ) + else -> throw IllegalStateException() } runner.run() @@ -283,7 +285,7 @@ fun MainView( keyboardController?.hide() focusManager.clearFocus() }), - enabled = (appViewModel.mode == RFCOMM_CLIENT_MODE) or (appViewModel.mode == L2CAP_CLIENT_MODE) + enabled = (appViewModel.mode == RFCOMM_CLIENT_MODE || appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == GATT_CLIENT_MODE) ) Divider() TextField( @@ -351,43 +353,36 @@ fun MainView( keyboardController?.hide() focusManager.clearFocus() }), - enabled = (appViewModel.scenario == PING_SCENARIO) + enabled = (appViewModel.scenario == PING_SCENARIO || appViewModel.scenario == SEND_SCENARIO) ) Divider() - ActionButton( - text = "Become Discoverable", onClick = becomeDiscoverable, true - ) Row( horizontalArrangement = Arrangement.SpaceBetween, verticalAlignment = Alignment.CenterVertically ) { Text(text = "2M PHY") Spacer(modifier = Modifier.padding(start = 8.dp)) - Switch( - enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE), + Switch(enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE || appViewModel.mode == GATT_CLIENT_MODE || appViewModel.mode == GATT_SERVER_MODE), checked = appViewModel.use2mPhy, - onCheckedChange = { appViewModel.use2mPhy = it } - ) + onCheckedChange = { appViewModel.use2mPhy = it }) Column(Modifier.selectableGroup()) { listOf( - "BALANCED", - "LOW", - "HIGH", - "DCK" + "BALANCED", "LOW", "HIGH", "DCK" ).forEach { text -> Row( Modifier .selectable( selected = (text == appViewModel.connectionPriority), onClick = { appViewModel.updateConnectionPriority(text) }, - role = Role.RadioButton + role = Role.RadioButton, ) .padding(horizontal = 16.dp), verticalAlignment = Alignment.CenterVertically ) { RadioButton( selected = (text == appViewModel.connectionPriority), - onClick = null + onClick = null, + enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE || appViewModel.mode == GATT_CLIENT_MODE || appViewModel.mode == GATT_SERVER_MODE) ) Text( text = text, @@ -404,7 +399,9 @@ fun MainView( RFCOMM_CLIENT_MODE, RFCOMM_SERVER_MODE, L2CAP_CLIENT_MODE, - L2CAP_SERVER_MODE + L2CAP_SERVER_MODE, + GATT_CLIENT_MODE, + GATT_SERVER_MODE ).forEach { text -> Row( Modifier @@ -417,8 +414,7 @@ fun MainView( verticalAlignment = Alignment.CenterVertically ) { RadioButton( - selected = (text == appViewModel.mode), - onClick = null + selected = (text == appViewModel.mode), onClick = null ) Text( text = text, @@ -430,10 +426,7 @@ fun MainView( } Column(Modifier.selectableGroup()) { listOf( - SEND_SCENARIO, - RECEIVE_SCENARIO, - PING_SCENARIO, - PONG_SCENARIO + SEND_SCENARIO, RECEIVE_SCENARIO, PING_SCENARIO, PONG_SCENARIO ).forEach { text -> Row( Modifier @@ -446,8 +439,7 @@ fun MainView( verticalAlignment = Alignment.CenterVertically ) { RadioButton( - selected = (text == appViewModel.scenario), - onClick = null + selected = (text == appViewModel.scenario), onClick = null ) Text( text = text, @@ -465,20 +457,29 @@ fun MainView( ActionButton( text = "Stop", onClick = appViewModel::abort, enabled = appViewModel.running ) + ActionButton( + text = "Become Discoverable", onClick = becomeDiscoverable, true + ) } Divider() - Text( - text = if (appViewModel.mtu != 0) "MTU: ${appViewModel.mtu}" else "" - ) - Text( - text = if (appViewModel.rxPhy != 0 || appViewModel.txPhy != 0) "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}" else "" - ) + if (appViewModel.mtu != 0) { + Text( + text = "MTU: ${appViewModel.mtu}" + ) + } + if (appViewModel.rxPhy != 0) { + Text( + text = "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}" + ) + } Text( text = "Status: ${appViewModel.status}" ) - Text( - text = "Last Error: ${appViewModel.lastError}" - ) + if (appViewModel.lastError.isNotEmpty()) { + Text( + text = "Last Error: ${appViewModel.lastError}" + ) + } Text( text = "Packets Sent: ${appViewModel.packetsSent}" ) diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt index b15c4feb..5b699496 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt @@ -35,6 +35,8 @@ const val L2CAP_CLIENT_MODE = "L2CAP Client" const val L2CAP_SERVER_MODE = "L2CAP Server" const val RFCOMM_CLIENT_MODE = "RFCOMM Client" const val RFCOMM_SERVER_MODE = "RFCOMM Server" +const val GATT_CLIENT_MODE = "GATT Client" +const val GATT_SERVER_MODE = "GATT Server" const val SEND_SCENARIO = "Send" const val RECEIVE_SCENARIO = "Receive" diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Packet.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Packet.kt index 0ccd8cf4..b1cf1741 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Packet.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Packet.kt @@ -17,6 +17,7 @@ package com.github.google.bumble.btbench import android.bluetooth.BluetoothSocket import java.io.IOException import java.nio.ByteBuffer +import java.nio.ByteOrder import java.util.logging.Logger import kotlin.math.min @@ -37,11 +38,16 @@ abstract class Packet(val type: Int, val payload: ByteArray = ByteArray(0)) { RESET -> ResetPacket() SEQUENCE -> SequencePacket( data[1].toInt(), - ByteBuffer.wrap(data, 2, 4).getInt(), - data.sliceArray(6.. AckPacket( + data[1].toInt(), + ByteBuffer.wrap(data, 2, 4).order(ByteOrder.LITTLE_ENDIAN).getInt() ) - ACK -> AckPacket(data[1].toInt(), ByteBuffer.wrap(data, 2, 4).getInt()) else -> GenericPacket(data[0].toInt(), data.sliceArray(1.. 0) { + val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds + val delay = targetTime - now + if (delay.isPositive()) { + Log.info("sleeping ${delay.inWholeMilliseconds} ms") + Thread.sleep(delay.inWholeMilliseconds) + now = TimeSource.Monotonic.markNow() + } } pingTimes.add(TimeSource.Monotonic.markNow()) packetIO.sendPacket( SequencePacket( if (i < packetCount - 1) 0 else Packet.LAST_FLAG, i, - ByteArray(packetSize - 6) + (now - startTime).inWholeMicroseconds.toInt(), + ByteArray(packetSize - 10) ) ) viewModel.packetsSent = i + 1 diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Sender.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Sender.kt index d248f3b9..50af5538 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Sender.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Sender.kt @@ -16,6 +16,7 @@ package com.github.google.bumble.btbench import java.util.concurrent.Semaphore import java.util.logging.Logger +import kotlin.time.Duration.Companion.milliseconds import kotlin.time.DurationUnit import kotlin.time.TimeSource @@ -45,20 +46,32 @@ class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO val packetCount = viewModel.senderPacketCount val packetSize = viewModel.senderPacketSize - for (i in 0.. 0) { + val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds + val delay = targetTime - now + if (delay.isPositive()) { + Log.info("sleeping ${delay.inWholeMilliseconds} ms") + Thread.sleep(delay.inWholeMilliseconds) + } + now = TimeSource.Monotonic.markNow() + } + val flags = when (i) { + packetCount - 1 -> Packet.LAST_FLAG + else -> 0 + } + packetIO.sendPacket( + SequencePacket( + flags, + i, + (now - startTime).inWholeMicroseconds.toInt(), + ByteArray(packetSize - 10) + ) + ) bytesSent += packetSize viewModel.packetsSent = i + 1 } - packetIO.sendPacket( - SequencePacket( - Packet.LAST_FLAG, - packetCount - 1, - ByteArray(packetSize - 6) - ) - ) - bytesSent += packetSize - viewModel.packetsSent = packetCount // Wait for the ACK Log.info("waiting for ACK")