Replies: 4 comments 5 replies
-
Hi @zucher, could you provide the full stack trace so I can see exactly where the error is coming from?
-
Pretty sure that the
-
@zucher I just realized that the dimensions must be inverted in your example. From
So you should simply provide an
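If it helps, an illustrative sketch of the two orientations at play (assuming a 20 ms packed mono frame at 48 kHz, i.e. 960 samples; the array here is a stand-in for `frame.to_ndarray()`, not diart API):

```python
import numpy as np

mono = np.zeros((1, 960), dtype=np.float32)  # what to_ndarray() returns for a packed mono frame
print(mono.swapaxes(0, 1).shape)             # (960, 1): samples-first
print(mono.shape)                            # (1, 960): the (1, samples) layout to provide instead
```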
-
Hi @juanmc2005, thank you for your reply. Unfortunately, after swapping the axes I still have an issue with:

```python
audio_frame = av.audio.resampler.AudioResampler(layout="mono").resample(audio_frame)[0]
self.stream.on_next(audio_frame.to_ndarray().swapaxes(0, 1))
```

My code:

Aiortc audio track interception:

```python
import asyncio
import logging
import logging.handlers
from pathlib import Path

import pydub
from aiortc import MediaStreamTrack
from aiortc.mediastreams import MediaStreamError
from diart.blocks.diarization import OnlineSpeakerDiarization
from diart.inference import RealTimeInference
from diart.sinks import RTTMWriter

from .WebRTCAudioSource import WebRTCAudioSource

HERE = Path(__file__).parent
logger = logging.getLogger(__name__)


class AudioDiarization(MediaStreamTrack):
    """An audio stream track that only listens."""

    kind = "audio"

    def __init__(self, track, channel, transform, event_emitter):
        super().__init__()  # don't forget this!
        self.track = track
        self.transform = transform
        self.ee = event_emitter
        self.channel = channel
        self.diart_init()
        self.sound_chunk = pydub.AudioSegment.empty()
        self.silent_count = 0

    @staticmethod
    def create_transformer(track, channel, transform, event_emitter):
        return AudioDiarization(track, channel, transform, event_emitter)

    def diart_init(self):
        # Build the diart pipeline and attach it to the custom WebRTC source.
        pipeline = OnlineSpeakerDiarization()
        self.source = WebRTCAudioSource("", 48000)
        inference = RealTimeInference(pipeline, self.source)
        inference.attach_hooks(lambda ann_wav: logger.info(ann_wav[0].to_rttm()))
        prediction = inference()

    async def recv(self):
        try:
            if self.track.readyState != "live":
                raise MediaStreamError
            audio_frame = await self.track.recv()
            # Forward each incoming frame to the diart source.
            self.source.push_audio_frame(audio_frame)
            return audio_frame
        except Exception as e:
            if self.track.readyState == "ended":
                raise e
            else:
                logger.error(e)
```

WebRTCAudioSource for diart integration:

```python
import av
import numpy as np
import pydub
from av.frame import Frame
from diart.sources import AudioSource
from rx.subject import Subject
from typing import Text, Optional, AnyStr, Dict, Any, Union, Tuple


class WebRTCAudioSource(AudioSource):
    def __init__(self, uri: Text, sample_rate: int):
        super().__init__(uri, sample_rate)
        self.stream = Subject()

    @property
    def duration(self) -> Optional[float]:
        """The duration of the stream if known. Defaults to None (unknown duration)."""
        return None

    def read(self):
        """Start reading the source and yielding samples through the stream."""
        pass

    def close(self):
        """Stop reading the source and close all open streams."""
        pass

    def push_audio_frame(self, audio_frame: Frame):
        # Downmix the incoming frame to mono, then emit its samples on the stream.
        audio_frame = av.audio.resampler.AudioResampler(layout="mono").resample(audio_frame)[0]
        self.stream.on_next(audio_frame.to_ndarray().swapaxes(0, 1))
```

The last line is failing.
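For reference, a minimal sketch of one way to normalize the array before `on_next`, assuming diart wants the `(1, samples)` layout quoted in the error message in the question below; `to_diart_waveform` is a hypothetical helper, not part of diart's API:

```python
import numpy as np

def to_diart_waveform(array: np.ndarray) -> np.ndarray:
    """Collapse stray singleton axes and return a (1, samples) waveform."""
    return np.atleast_2d(array.squeeze())

# Both problematic shapes from this thread normalize to (1, 960):
assert to_diart_waveform(np.zeros((1, 1, 960))).shape == (1, 960)
assert to_diart_waveform(np.zeros((960, 1))).shape == (1, 960)
```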
-
Hi,
I'm currently building a service that retrieves the AudioFrames generated by the aiortc WebRTC stack, in order to run live diarization.
I convert the AudioFrame to mono and extract the associated ndarray; the resulting shape is (1, 960).
The resulting error is:
`Waveform must have shape (1, samples) but (1, 1, 960) was found`
When I squeeze it, I get the following error:
`Temporal features must be 2D or 3D`
So I don't know how to get past this error. Has someone already tried to do something similar?
Thanks
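For what it's worth, a small repro of the two shape errors described above, assuming the array that reaches diart has the reported shape (1, 1, 960):

```python
import numpy as np

chunk = np.zeros((1, 1, 960), dtype=np.float32)  # shape from the first error message

print(chunk.squeeze().shape)   # (960,): 1D, which triggers "Temporal features must be 2D or 3D"
print(chunk.squeeze(0).shape)  # (1, 960): drop only the leading axis, keeping (1, samples)
```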