Skip to content

Commit

Permalink
Replace liblc3 wasm library
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Dec 18, 2024
1 parent 6ae3f09 commit 86ecd81
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 184 deletions.
215 changes: 32 additions & 183 deletions apps/lea_unicast/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,18 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct

import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
import lc3
import wave

import click
import aiohttp.web
Expand All @@ -45,6 +40,7 @@
from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
Expand All @@ -54,6 +50,7 @@
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
DEFAULT_PCM_BYTES_PER_SAMPLE = 2


def _sink_pac_record() -> pacs.PacRecord:
Expand Down Expand Up @@ -100,153 +97,8 @@ def _source_pac_record() -> pacs.PacRecord:
)


# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)


class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3


MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)

DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2


encoders: List[int] = []
decoders: List[int] = []


def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]


def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]


def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''

input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels

# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore

output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)

for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)

if res != 0:
logging.error(f"Parsing failed, res={res}")

# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)


def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''

input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)

# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore

output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels

for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)

if res != 0:
logging.error(f"Parsing failed, res={res}")

# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
decoder: lc3.Decoder | None = None
encoder: lc3.Encoder | None = None


async def lc3_source_task(
Expand All @@ -256,28 +108,24 @@ async def lc3_source_task(
device: Device,
cis_handle: int,
) -> None:
with open(filename, 'rb') as f:
header = f.read(44)
assert header[8:12] == b'WAVE'

pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
assert encoder
with wave.open(filename, 'rb') as wav:
assert (
bits_per_sample := wav.getsampwidth()
) == DEFAULT_PCM_BYTES_PER_SAMPLE * 8

frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
frame_samples = encoder.get_frame_samples()
packet_sequence_number = 0

while True:
next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us
)
pcm_data = f.read(frame_bytes)
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
sdu = encoder.encode(
pcm=wav.readframes(frame_samples),
num_bytes=sdu_length,
bit_depth=bits_per_sample,
)

iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
Expand Down Expand Up @@ -410,7 +258,7 @@ class Speaker:

def __init__(
self,
device_config_path: Optional[str],
device_config_path: str | None,
ui_port: int,
transport: str,
lc3_input_file_path: str,
Expand Down Expand Up @@ -490,12 +338,11 @@ def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
not isinstance(codec_config, bap.CodecSpecificConfiguration)
or codec_config.frame_duration is None
or codec_config.audio_channel_allocation is None
or decoder is None
):
return
pcm = decode(
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
pdu.iso_sdu_fragment,
pcm = decoder.decode(
pdu.iso_sdu_fragment, bit_depth=DEFAULT_PCM_BYTES_PER_SAMPLE * 8
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))

Expand Down Expand Up @@ -537,16 +384,18 @@ def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
):
return
if ase.role == ascs.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
global encoder
encoder = lc3.Encoder(
frame_duration_us=codec_config.frame_duration.us,
sample_rate_hz=codec_config.sampling_frequency.hz,
num_channels=codec_config.audio_channel_allocation.channel_count,
)
else:
setup_decoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
global decoder
decoder = lc3.Decoder(
frame_duration_us=codec_config.frame_duration.us,
sample_rate_hz=codec_config.sampling_frequency.hz,
num_channels=codec_config.audio_channel_allocation.channel_count,
)

for ase in ascs_service.ase_state_machines.values():
Expand Down
Binary file removed apps/lea_unicast/liblc3.wasm
Binary file not shown.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ development = [
"types-appdirs >= 1.4.3",
"types-invoke >= 1.7.3",
"types-protobuf >= 4.21.0",
"wasmtime == 20.0.0",
]
avatar = [
"pandora-avatar == 0.0.10",
Expand All @@ -66,6 +65,9 @@ documentation = [
"mkdocs-material >= 8.5.6",
"mkdocstrings[python] >= 0.19.0",
]
lc3 = [
"lc3 @ git+https://github.com/google/liblc3.git",
]

[project.scripts]
bumble-ble-rpa-tool = "bumble.apps.ble_rpa_tool:main"
Expand Down

0 comments on commit 86ecd81

Please sign in to comment.