Skip to content

Commit

Permalink
BIG Sync app
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Dec 17, 2024
1 parent a3bb51a commit 417141d
Showing 1 changed file with 262 additions and 12 deletions.
274 changes: 262 additions & 12 deletions apps/auracast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,20 @@
import asyncio
import contextlib
import dataclasses
import enum
import functools
import logging
import os
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple

import click
import pyee

import ctypes
import wasmtime
import wasmtime.loader
from lea_unicast import liblc3 # type: ignore # pylint: disable=E0401

from bumble.colors import color
import bumble.company_ids
import bumble.core
Expand Down Expand Up @@ -54,6 +61,95 @@
AURACAST_DEFAULT_ATT_MTU = 256


# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)


class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3


MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)

DECODER_STACK_POINTER = STACK_POINTER
DECODE_BUFFER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.FLOAT
DEFAULT_PCM_BYTES_PER_SAMPLE = 4

decoders = list[int]()


def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]


def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''

input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels

# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore

output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)

for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)

if res != 0:
logging.error(f"Parsing failed, res={res}")

# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)


# -----------------------------------------------------------------------------
# Scan For Broadcasts
# -----------------------------------------------------------------------------
Expand All @@ -62,6 +158,7 @@ class BroadcastScanner(pyee.EventEmitter):
class Broadcast(pyee.EventEmitter):
name: str | None
sync: bumble.device.PeriodicAdvertisingSync
broadcast_id: int
rssi: int = 0
public_broadcast_announcement: Optional[
bumble.profiles.pbp.PublicBroadcastAnnouncement
Expand Down Expand Up @@ -280,11 +377,14 @@ def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
bumble.core.AdvertisingData.SERVICE_DATA_16_BIT_UUID
)
) or not (
any(
ad
for ad in ads
if isinstance(ad, tuple)
and ad[0] == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
broadcast_audio_announcement := next(
(
ad
for ad in ads
if isinstance(ad, tuple)
and ad[0] == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
),
None,
)
):
return
Expand All @@ -293,25 +393,35 @@ def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
bumble.core.AdvertisingData.BROADCAST_NAME
)
assert isinstance(broadcast_name, str) or broadcast_name is None
assert isinstance(broadcast_audio_announcement[1], bytes)

if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement)
return

bumble.utils.AsyncRunner.spawn(
self.on_new_broadcast(broadcast_name, advertisement)
self.on_new_broadcast(
broadcast_name,
advertisement,
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(
broadcast_audio_announcement[1]
).broadcast_id,
)
)

async def on_new_broadcast(
self, name: str | None, advertisement: bumble.device.Advertisement
self,
name: str | None,
advertisement: bumble.device.Advertisement,
broadcast_id: int,
) -> None:
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
advertiser_address=advertisement.address,
sid=advertisement.sid,
sync_timeout=self.sync_timeout,
filter_duplicates=self.filter_duplicates,
)
broadcast = self.Broadcast(name, periodic_advertising_sync)
broadcast = self.Broadcast(name, periodic_advertising_sync, broadcast_id)
broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
Expand All @@ -323,10 +433,11 @@ def on_broadcast_loss(self, broadcast: Broadcast) -> None:
self.emit('broadcast_loss', broadcast)


class PrintingBroadcastScanner:
class PrintingBroadcastScanner(pyee.EventEmitter):
def __init__(
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
) -> None:
super().__init__()
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
self.scanner.on('new_broadcast', self.on_new_broadcast)
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
Expand Down Expand Up @@ -610,6 +721,116 @@ async def run_pair(transport: str, address: str) -> None:
print("+++ Paired")


async def run_receive(
transport: str,
broadcast_id: int,
broadcast_code: str | None,
sync_timeout: float,
subgroup_index: int,
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return

scanner = BroadcastScanner(device, False, sync_timeout)
scan_result: asyncio.Future[BroadcastScanner.Broadcast] = (
asyncio.get_running_loop().create_future()
)

def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
if scan_result.done():
return
if broadcast.broadcast_id == broadcast_id:
scan_result.set_result(broadcast)

scanner.on('new_broadcast', on_new_broadcast)
await scanner.start()
print('Start scanning...')
broadcast = await scan_result
print('Advertisement found:')
broadcast.print()
basic_audio_announcement_scanned = asyncio.Event()

def on_change() -> None:
if (
broadcast.basic_audio_announcement
and not basic_audio_announcement_scanned.is_set()
):
basic_audio_announcement_scanned.set()

broadcast.on('change', on_change)
if not broadcast.basic_audio_announcement:
print('Wait for Basic Audio Announcement...')
await basic_audio_announcement_scanned.wait()
print('Basic Audio Announcement found')
broadcast.print()
print('Stop scanning')
await scanner.stop()
print('Start sync to BIG')

assert broadcast.basic_audio_announcement
subgroup = broadcast.basic_audio_announcement.subgroups[subgroup_index]
configuration = subgroup.codec_specific_configuration
assert configuration
assert (sampling_frequency := configuration.sampling_frequency)
assert (frame_duration := configuration.frame_duration)

big_sync = await device.create_big_sync(
broadcast.sync,
bumble.device.BigSyncParameters(
big_sync_timeout=0x4000,
bis=[bis.index for bis in subgroup.bis],
broadcast_code=(
bytes.fromhex(broadcast_code) if broadcast_code else None
),
),
)
num_bis = len(big_sync.bis_links)
setup_decoders(
sampling_frequency.hz,
frame_duration.us,
num_bis,
)
sdus = [b''] * num_bis
subprocess = await asyncio.create_subprocess_shell(
f'stdbuf -i0 ffplay -ar {DEFAULT_PCM_SAMPLE_RATE} -ac {num_bis} -f f32le pipe:0',
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
for i, bis_link in enumerate(big_sync.bis_links):
print(f'Setup ISO for BIS {bis_link.handle}')

def sink(index: int, packet: bumble.hci.HCI_IsoDataPacket):
nonlocal sdus
sdus[index] = packet.iso_sdu_fragment
if all(sdus) and subprocess.stdin:
subprocess.stdin.write(
decode(frame_duration.us, num_bis, b''.join(sdus))
)
sdus = [b''] * num_bis

bis_link.sink = functools.partial(sink, i)
await device.send_command(
bumble.hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=bis_link.handle,
data_path_direction=bumble.hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST,
data_path_id=0,
codec_id=bumble.hci.CodingFormat(
codec_id=bumble.hci.CodecID.TRANSPARENT
),
controller_delay=0,
codec_configuration=b'',
),
check_result=True,
)

terminated = asyncio.Event()
big_sync.on(big_sync.Event.TERMINATION, lambda _: terminated.set())
await terminated.wait()


def run_async(async_command: Coroutine) -> None:
try:
asyncio.run(async_command)
Expand All @@ -631,9 +852,7 @@ def run_async(async_command: Coroutine) -> None:
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
def auracast(
ctx,
):
def auracast(ctx):
ctx.ensure_object(dict)


Expand Down Expand Up @@ -691,6 +910,37 @@ def pair(ctx, transport, address):
run_async(run_pair(transport, address))


@auracast.command('receive')
@click.argument('transport')
@click.argument('broadcast_id', type=int)
@click.option(
'--broadcast-code',
metavar='BROADCAST_CODE',
type=str,
help='Broadcast encryption code in hex format',
)
@click.option(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
help='Sync timeout (in seconds)',
)
@click.option(
'--subgroup',
metavar='SUBGROUP',
type=int,
default=0,
help='Index of Subgroup',
)
@click.pass_context
def receive(ctx, transport, broadcast_id, broadcast_code, sync_timeout, subgroup):
"""Receive a broadcast source"""
run_async(
run_receive(transport, broadcast_id, broadcast_code, sync_timeout, subgroup)
)


def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
auracast()
Expand Down

0 comments on commit 417141d

Please sign in to comment.