From 46707abd072e84f035b8f2763d9749b85c645076 Mon Sep 17 00:00:00 2001
From: xx205
Date: Sun, 11 Aug 2024 15:51:11 +0000
Subject: [PATCH 01/12] Add diarization recipe v3

---
 examples/voxconverse/v3/README.md            |  33 ++++
 examples/voxconverse/v3/local/extract_emb.sh |  63 +++++++
 examples/voxconverse/v3/local/make_fbank.sh  |  52 ++++++
 examples/voxconverse/v3/path.sh              |   5 +
 examples/voxconverse/v3/run.sh               | 186 +++++++++++++++++++
 examples/voxconverse/v3/tools                |   1 +
 examples/voxconverse/v3/wespeaker            |   1 +
 wespeaker/diar/pahc.py                       | 165 ++++++++++++++++
 wespeaker/diar/umap_clusterer.py             | 114 ++++++++++++
 9 files changed, 620 insertions(+)
 create mode 100644 examples/voxconverse/v3/README.md
 create mode 100755 examples/voxconverse/v3/local/extract_emb.sh
 create mode 100755 examples/voxconverse/v3/local/make_fbank.sh
 create mode 100644 examples/voxconverse/v3/path.sh
 create mode 100755 examples/voxconverse/v3/run.sh
 create mode 120000 examples/voxconverse/v3/tools
 create mode 120000 examples/voxconverse/v3/wespeaker
 create mode 100644 wespeaker/diar/pahc.py
 create mode 100644 wespeaker/diar/umap_clusterer.py

diff --git a/examples/voxconverse/v3/README.md b/examples/voxconverse/v3/README.md
new file mode 100644
index 00000000..02f41fa1
--- /dev/null
+++ b/examples/voxconverse/v3/README.md
@@ -0,0 +1,33 @@
+## Overview
+
+* We suggest running this recipe on a GPU-enabled machine with onnxruntime-gpu installed.
+* Dataset: voxconverse_dev, which consists of 216 utterances
+* Speaker model: ResNet34 model pretrained by wespeaker
+    * Refer to [voxceleb sv recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v2)
+    * [pretrained model path](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx)
+* Speaker activity detection model: oracle SAD (from ground truth annotation) or system SAD (VAD model pretrained by silero, https://github.com/snakers4/silero-vad)
+* Clustering method: spectral clustering
+* Metric: DER = MISS + FALSE ALARM + SPEAKER CONFUSION (%)
+
+## Results
+
+* Dev set
+
+  | system | MISS | FA | SC | DER |
+  |:---|:---:|:---:|:---:|:---:|
+  | This repo (with oracle SAD) | 2.3 | 0.0 | 2.1 | 4.4 |
+  | This repo (with system SAD) | 3.7 | 0.8 | 2.2 | 6.8 |
+  | DIHARD 2019 baseline [^1] | 11.1 | 1.4 | 11.3 | 23.8 |
+  | DIHARD 2019 baseline w/ SE [^1] | 9.3 | 1.3 | 9.7 | 20.2 |
+  | (SyncNet ASD only) [^1] | 2.2 | 4.1 | 4.0 | 10.4 |
+  | (AVSE ASD only) [^1] | 2.0 | 5.9 | 4.6 | 12.4 |
+  | (proposed) [^1] | 2.4 | 2.3 | 3.0 | 7.7 |
+
+* Test set
+
+  | system | MISS | FA | SC | DER |
+  |:---|:---:|:---:|:---:|:---:|
+  | This repo (with system SAD) | 4.0 | 2.4 | 3.4 | 9.8 |
+
+
+[^1]: Spot the conversation: speaker diarisation in the wild, https://arxiv.org/pdf/2007.01216.pdf
diff --git a/examples/voxconverse/v3/local/extract_emb.sh b/examples/voxconverse/v3/local/extract_emb.sh
new file mode 100755
index 00000000..b12a1c04
--- /dev/null
+++ b/examples/voxconverse/v3/local/extract_emb.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+# Copyright (c) 2022 Zhengyang Chen (chenzhengyang117@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+. ./path.sh || exit 1
+
+scp=''
+pretrained_model=''
+device=cuda
+store_dir=''
+subseg_cmn=true
+nj=1
+
+batch_size=96
+frame_shift=10
+window_secs=1.5
+period_secs=0.75
+
+. tools/parse_options.sh
+
+split_dir=$store_dir/split_scp
+log_dir=$store_dir/log
+mkdir -p $split_dir
+mkdir -p $log_dir
+
+# split the scp file into sub-files so that multiple processes can extract embeddings in parallel
+file_len=`wc -l $scp | awk '{print $1}'`
+subfile_len=$[$file_len / $nj + 1]
+prefix='split'
+split -l $subfile_len -d -a 3 $scp ${split_dir}/${prefix}_scp_
+
+for suffix in `seq 0 $[$nj-1]`;do
+    suffix=`printf '%03d' $suffix`
+    scp_subfile=${split_dir}/${prefix}_scp_${suffix}
+    write_ark=$store_dir/emb_${suffix}.ark
+    python3 wespeaker/diar/extract_emb.py \
+            --scp ${scp_subfile} \
+            --ark-path ${write_ark} \
+            --source ${pretrained_model} \
+            --device ${device} \
+            --batch-size ${batch_size} \
+            --frame-shift ${frame_shift} \
+            --window-secs ${window_secs} \
+            --period-secs ${period_secs} \
+            --subseg-cmn ${subseg_cmn} \
+            > ${log_dir}/${prefix}.${suffix}.log 2>&1 &
+done
+
+wait
+
+cat $store_dir/emb_*.scp > $store_dir/emb.scp
+echo "Finished extracting embeddings."
diff --git a/examples/voxconverse/v3/local/make_fbank.sh b/examples/voxconverse/v3/local/make_fbank.sh
new file mode 100755
index 00000000..3224c151
--- /dev/null
+++ b/examples/voxconverse/v3/local/make_fbank.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+# Copyright (c) 2022 Zhengyang Chen (chenzhengyang117@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+. ./path.sh || exit 1
+
+scp=''
+segments=''
+store_dir=''
+subseg_cmn=true
+nj=1
+
+. tools/parse_options.sh
+
+split_dir=$store_dir/split_scp
+log_dir=$store_dir/log
+mkdir -p $split_dir
+mkdir -p $log_dir
+
+# split the scp file into sub-files so that multiple processes can extract Fbank features in parallel
+file_len=`wc -l $scp | awk '{print $1}'`
+subfile_len=$[$file_len / $nj + 1]
+prefix='split'
+split -l $subfile_len -d -a 3 $scp ${split_dir}/${prefix}_scp_
+
+for suffix in `seq 0 $[$nj-1]`;do
+    suffix=`printf '%03d' $suffix`
+    scp_subfile=${split_dir}/${prefix}_scp_${suffix}
+    write_ark=$store_dir/fbank_${suffix}.ark
+    python3 wespeaker/diar/make_fbank.py \
+            --scp ${scp_subfile} \
+            --segments ${segments} \
+            --ark-path ${write_ark} \
+            --subseg-cmn ${subseg_cmn} \
+            > ${log_dir}/${prefix}.${suffix}.log 2>&1 &
+done
+
+wait
+
+cat $store_dir/fbank_*.scp > $store_dir/fbank.scp
+echo "Finished making Fbank features."
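The `window_secs=1.5` / `period_secs=0.75` parameters passed to `extract_emb.py` above tile each SAD segment into overlapping sub-segments, and one speaker embedding is extracted per sub-segment. Below is a minimal sketch of that tiling, assuming the recipe's 10 ms frame shift; the helper name is illustrative, and the exact boundary handling in `wespeaker/diar/extract_emb.py` may differ:

```python
# Illustrative sketch (not the wespeaker implementation) of how a
# window_secs/period_secs pair tiles a speech segment into overlapping
# sub-segments; each sub-segment later receives one speaker embedding.
def subsegment_frames(num_frames, frame_shift_ms=10,
                      window_secs=1.5, period_secs=0.75):
    window = int(window_secs * 1000 / frame_shift_ms)  # 150 frames
    period = int(period_secs * 1000 / frame_shift_ms)  # 75 frames
    subsegs, start = [], 0
    while start < num_frames:
        end = min(start + window, num_frames)
        subsegs.append((start, end))
        if end == num_frames:
            break
        start += period
    return subsegs

# A 3-second segment (300 frames at 10 ms) yields three overlapping windows:
print(subsegment_frames(300))  # [(0, 150), (75, 225), (150, 300)]
```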
diff --git a/examples/voxconverse/v3/path.sh b/examples/voxconverse/v3/path.sh
new file mode 100644
index 00000000..b90a5154
--- /dev/null
+++ b/examples/voxconverse/v3/path.sh
@@ -0,0 +1,5 @@
+export PATH=$PWD:$PATH
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=../../../:$PYTHONPATH
diff --git a/examples/voxconverse/v3/run.sh b/examples/voxconverse/v3/run.sh
new file mode 100755
index 00000000..f53cfab5
--- /dev/null
+++ b/examples/voxconverse/v3/run.sh
@@ -0,0 +1,186 @@
+#!/bin/bash
+# Copyright (c) 2022-2023 Xu Xiang
+#               2022 Zhengyang Chen (chenzhengyang117@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+. ./path.sh || exit 1
+
+stage=-1
+stop_stage=-1
+sad_type="oracle"
+partition="dev"
+
+# do cmn on the sub-segment or on the vad segment
+subseg_cmn=true
+# whether to print the evaluation result for each file
+get_each_file_res=1
+
+. tools/parse_options.sh
+
+# Prerequisite
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    mkdir -p external_tools
+
+    # [1] Download evaluation toolkit
+    wget -c https://github.com/usnistgov/SCTK/archive/refs/tags/v2.4.12.zip -O external_tools/SCTK-v2.4.12.zip
+    unzip -o external_tools/SCTK-v2.4.12.zip -d external_tools
+
+    # [2] Download ResNet34 speaker model pretrained by WeSpeaker Team
+    mkdir -p pretrained_models
+
+    wget -c https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx -O pretrained_models/voxceleb_resnet34_LM.onnx
+fi
+
+
+# Download VoxConverse dev/test audios and the corresponding annotations
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    mkdir -p data
+
+    # Download annotations for dev and test sets (version 0.0.3)
+    wget -c https://github.com/joonson/voxconverse/archive/refs/heads/master.zip -O data/voxconverse_master.zip
+    unzip -o data/voxconverse_master.zip -d data
+
+    # Download annotations from VoxSRC-23 validation toolkit (looks like version 0.0.2)
+    # cd data && git clone https://github.com/JaesungHuh/VoxSRC2023.git --recursive && cd -
+
+    # Download dev audios
+    mkdir -p data/dev
+
+    #wget --no-check-certificate -c https://mm.kaist.ac.kr/datasets/voxconverse/data/voxconverse_dev_wav.zip -O data/voxconverse_dev_wav.zip
+    # The above url may not be reachable; you can try the link below.
+    # This url is from https://github.com/joonson/voxconverse/blob/master/README.md
+    wget --no-check-certificate -c https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_dev_wav.zip -O data/voxconverse_dev_wav.zip
+    unzip -o data/voxconverse_dev_wav.zip -d data/dev
+
+    # Create wav.scp for dev audios
+    ls `pwd`/data/dev/audio/*.wav | awk -F/ '{print substr($NF, 1, length($NF)-4), $0}' > data/dev/wav.scp
+
+    # Test audios
+    mkdir -p data/test
+
+    #wget --no-check-certificate -c https://mm.kaist.ac.kr/datasets/voxconverse/data/voxconverse_test_wav.zip -O data/voxconverse_test_wav.zip
+    # The above url may not be reachable; you can try the link below.
+    # This url is from https://github.com/joonson/voxconverse/blob/master/README.md
+    wget --no-check-certificate -c https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_test_wav.zip -O data/voxconverse_test_wav.zip
+    unzip -o data/voxconverse_test_wav.zip -d data/test
+
+    # Create wav.scp for test audios
+    ls `pwd`/data/test/voxconverse_test_wav/*.wav | awk -F/ '{print substr($NF, 1, length($NF)-4), $0}' > data/test/wav.scp
+fi
+
+
+# Voice activity detection
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # Set VAD min duration
+    min_duration=0.255
+
+    if [[ "x${sad_type}" == "xoracle" ]]; then
+        # Oracle SAD: handling overlapping or too short regions in ground truth RTTM
+        while read -r utt wav_path; do
+            python3 wespeaker/diar/make_oracle_sad.py \
+                    --rttm data/voxconverse-master/${partition}/${utt}.rttm \
+                    --min-duration $min_duration
+        done < data/${partition}/wav.scp > data/${partition}/oracle_sad
+    fi
+
+    if [[ "x${sad_type}" == "xsystem" ]]; then
+        # System SAD: applying 'silero' VAD
+        python3 wespeaker/diar/make_system_sad.py \
+                --scp data/${partition}/wav.scp \
+                --min-duration $min_duration > data/${partition}/system_sad
+    fi
+fi
+
+
+# Extract fbank features
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+
+    [ -d "exp/${partition}_${sad_type}_sad_fbank" ] && rm -r exp/${partition}_${sad_type}_sad_fbank
+
+    echo "Make Fbank features and store them under exp/${partition}_${sad_type}_sad_fbank"
+    echo "..."
+    bash local/make_fbank.sh \
+            --scp data/${partition}/wav.scp \
+            --segments data/${partition}/${sad_type}_sad \
+            --store_dir exp/${partition}_${sad_type}_sad_fbank \
+            --subseg_cmn ${subseg_cmn} \
+            --nj 24
+fi
+
+# Extract embeddings
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+
+    [ -d "exp/${partition}_${sad_type}_sad_embedding" ] && rm -r exp/${partition}_${sad_type}_sad_embedding
+
+    echo "Extract embeddings and store them under exp/${partition}_${sad_type}_sad_embedding"
+    echo "..."
+    bash local/extract_emb.sh \
+            --scp exp/${partition}_${sad_type}_sad_fbank/fbank.scp \
+            --pretrained_model pretrained_models/voxceleb_resnet34_LM.onnx \
+            --device cuda \
+            --store_dir exp/${partition}_${sad_type}_sad_embedding \
+            --batch_size 96 \
+            --frame_shift 10 \
+            --window_secs 1.5 \
+            --period_secs 0.75 \
+            --subseg_cmn ${subseg_cmn} \
+            --nj 1
+fi
+
+
+# Apply the UMAP clustering algorithm
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+
+    [ -f "exp/umap_cluster/${partition}_${sad_type}_sad_labels" ] && rm exp/umap_cluster/${partition}_${sad_type}_sad_labels
+
+    echo "Doing UMAP clustering and storing the result in exp/umap_cluster/${partition}_${sad_type}_sad_labels"
+    echo "..."
+    python3 wespeaker/diar/umap_clusterer.py \
+            --scp exp/${partition}_${sad_type}_sad_embedding/emb.scp \
+            --output exp/umap_cluster/${partition}_${sad_type}_sad_labels
+fi
+
+
+# Convert labels to RTTMs
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+    python3 wespeaker/diar/make_rttm.py \
+            --labels exp/umap_cluster/${partition}_${sad_type}_sad_labels \
+            --channel 1 > exp/umap_cluster/${partition}_${sad_type}_sad_rttm
+fi
+
+
+# Evaluate the result
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+    ref_dir=data/voxconverse-master/
+    #ref_dir=data/VoxSRC2023/voxconverse/
+    echo -e "Get the DER results\n..."
+    perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \
+         -c 0.25 \
+         -r <(cat ${ref_dir}/${partition}/*.rttm) \
+         -s exp/umap_cluster/${partition}_${sad_type}_sad_rttm 2>&1 | tee exp/umap_cluster/${partition}_${sad_type}_sad_res
+
+    if [ ${get_each_file_res} -eq 1 ];then
+        single_file_res_dir=exp/umap_cluster/${partition}_${sad_type}_single_file_res
+        mkdir -p $single_file_res_dir
+        echo -e "\nGet the DER results for each file and the results will be stored under ${single_file_res_dir}\n..."
+
+        awk '{print $2}' exp/umap_cluster/${partition}_${sad_type}_sad_rttm | sort -u | while read file_name; do
+            perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \
+                 -c 0.25 \
+                 -r <(cat ${ref_dir}/${partition}/${file_name}.rttm) \
+                 -s <(grep "${file_name}" exp/umap_cluster/${partition}_${sad_type}_sad_rttm) > ${single_file_res_dir}/${partition}_${file_name}_res
+        done
+        echo "Done!"
+    fi
+fi
diff --git a/examples/voxconverse/v3/tools b/examples/voxconverse/v3/tools
new file mode 120000
index 00000000..c92f4172
--- /dev/null
+++ b/examples/voxconverse/v3/tools
@@ -0,0 +1 @@
+../../../tools
\ No newline at end of file
diff --git a/examples/voxconverse/v3/wespeaker b/examples/voxconverse/v3/wespeaker
new file mode 120000
index 00000000..900c560b
--- /dev/null
+++ b/examples/voxconverse/v3/wespeaker
@@ -0,0 +1 @@
+../../../wespeaker
\ No newline at end of file
diff --git a/wespeaker/diar/pahc.py b/wespeaker/diar/pahc.py
new file mode 100644
index 00000000..4b5c6460
--- /dev/null
+++ b/wespeaker/diar/pahc.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2023 Xu Xiang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + +import heapq +from collections import defaultdict + +import numpy as np + + +def l2norm(x, axis=0, keepdims=True): + return x / np.linalg.norm(x, axis=axis, keepdims=keepdims) + + +class PAHC: + def __init__(self, merge_cutoff=0.3, min_cluster_size=3, absorb_cutoff=0.0): + self.merge_cutoff = merge_cutoff + self.min_cluster_size = min_cluster_size + self.absorb_cutoff = absorb_cutoff + + def fit_predict(self, labels, embeddings): + self.initialize(labels, embeddings) + self.merge_cluster() + self.absorb_cluster() + labels = self.relabel_cluster() + return labels + + def initialize(self, labels, embeddings): + self.labels = labels + self.embeddings = embeddings + self.active_clusters = set([]) + self.label_map = {} + self.cost_map = {} + self.heap = [] + self.next_index = -1 + + self.build_label_map() + self.build_cost_map() + + def merge_cluster(self): + while self.heap: + _, (i, j) = heapq.heappop(self.heap) + if i in self.active_clusters and j in self.active_clusters: + self.merge(i, j) + + def absorb_cluster(self): + minor_clusters = set() + major_clusters = set() + for k, indexes in self.label_map.items(): + if len(indexes) < self.min_cluster_size: + minor_clusters.add(k) + else: + major_clusters.add(k) + + if len(major_clusters) > 0: + for i in minor_clusters: + max_cost = -np.inf + for j in major_clusters: + pair = (i, j) if i < j else (j, i) + i_indexes, j_indexes = self.label_map[i], self.label_map[j] + factor = len(i_indexes) * len(j_indexes) + normalized_cost = self.cost_map[pair] / factor + if normalized_cost > max_cost: + max_cost = normalized_cost + closest_cluster = j + if max_cost >= self.absorb_cutoff: + self.label_map[closest_cluster].extend(self.label_map[i]) + self.eliminate(i) + + def relabel_cluster(self): + labels = [-1] * len(self.labels) + + for label, indexes in self.label_map.items(): + for index in indexes: + labels[index] = label + i = 0 + label_to_label = {} + for label in labels: + if label not in label_to_label: + label_to_label[label] = i + i += 1 + for i in range(len(labels)): + labels[i] = label_to_label[labels[i]] + return labels + + def eliminate(self, i): + del self.label_map[i] + self.active_clusters.remove(i) + + def build_label_map(self): + self.label_map = defaultdict(list) + + for i, label in enumerate(self.labels): + self.label_map[label].append(i) + + self.num_labeled = len(self.label_map) + + if -1 in self.label_map: + self.num_labeled -= 1 + for i, j in zip(range(self.num_labeled, + self.num_labeled + len(self.label_map[-1])), + self.label_map[-1]): + self.label_map[i].append(j) + del self.label_map[-1] + + def build_cost_map(self): + N = len(self.label_map) + self.active_clusters = set(range(N)) + self.next_index = N + + for i in range(N): + for j in range(i + 1, N): + i_indexes, j_indexes = self.label_map[i], self.label_map[j] + + if i < self.num_labeled and j < self.num_labeled: + self.cost_map[(i, j)] = -np.inf + continue + + self.cost_map[(i, j)] = self.compute_cost(i_indexes, j_indexes) + + factor = len(i_indexes) * len(j_indexes) + normalized_cost = self.cost_map[(i, j)] / factor + if normalized_cost >= self.merge_cutoff: + heapq.heappush(self.heap, (-normalized_cost, (i, j))) + + def compute_cost(self, i_indexes, j_indexes): + i_embedding = sum([ + l2norm(self.embeddings[i_index]) for i_index in i_indexes]) + j_embedding = sum([ + l2norm(self.embeddings[j_index]) for j_index in j_indexes]) + return np.dot(i_embedding, j_embedding) + + def merge(self, i, j): + i_indexes, j_indexes = self.label_map[i], self.label_map[j] + + for k, _ 
in self.label_map.items():
+            if k == i or k == j:
+                continue
+            pair1 = (k, i) if k < i else (i, k)
+            pair2 = (k, j) if k < j else (j, k)
+            cost = self.cost_map[pair1] + self.cost_map[pair2]
+            self.cost_map[(k, self.next_index)] = cost
+
+            factor = (len(i_indexes) + len(j_indexes)) * len(self.label_map[k])
+            normalized_cost = cost / factor
+            if normalized_cost >= self.merge_cutoff:
+                heapq.heappush(self.heap, (-normalized_cost, 
+                                           (k, self.next_index)))
+
+        self.label_map[self.next_index] = i_indexes + j_indexes
+        self.active_clusters.add(self.next_index)
+        self.eliminate(i)
+        self.eliminate(j)
+        self.next_index += 1
diff --git a/wespeaker/diar/umap_clusterer.py b/wespeaker/diar/umap_clusterer.py
new file mode 100644
index 00000000..aae01368
--- /dev/null
+++ b/wespeaker/diar/umap_clusterer.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2023 Xu Xiang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+os.environ["OPENBLAS_NUM_THREADS"] = "1"
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["NUMBA_NUM_THREADS"] = "1"
+
+import argparse
+import concurrent.futures
+from collections import OrderedDict
+
+import numpy as np
+
+import kaldiio
+import umap
+import hdbscan
+import pahc
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--scp', required=True, help='embedding scp')
+    parser.add_argument('--output', required=True, help='output label file')
+    parser.add_argument('--n_neighbors', required=False, default=16,
+                        help="The size of the local neighborhood UMAP "
+                             "will look at when attempting to learn "
+                             "the manifold structure of the data. "
+                             "This means that low values of n_neighbors "
+                             "will force UMAP to concentrate on "
+                             "very local structure (potentially to "
+                             "the detriment of the big picture), "
+                             "while large values will push UMAP to "
+                             "look at larger neighborhoods of each point "
+                             "when estimating the manifold structure of "
+                             "the data, losing fine detail structure for "
+                             "the sake of getting a broader view of the data.")
+    parser.add_argument('--min_dist', required=False, default=0.1,
+                        help="The minimum distance between points in "
+                             "the low dimensional representation.")
+    args = parser.parse_args()
+    return args
+
+
+def read_emb(scp):
+    emb_dict = OrderedDict()
+    for sub_seg_id, emb in kaldiio.load_scp_sequential(scp):
+        utt = sub_seg_id.split('-')[0]
+        if utt not in emb_dict:
+            emb_dict[utt] = {}
+            emb_dict[utt]['sub_seg'] = []
+            emb_dict[utt]['embs'] = []
+        emb_dict[utt]['sub_seg'].append(sub_seg_id)
+        emb_dict[utt]['embs'].append(emb)
+    subsegs_list = []
+    embeddings_list = []
+    for utt, utt_emb_dict in emb_dict.items():
+        subsegs_list.append(utt_emb_dict['sub_seg'])
+        embeddings_list.append(np.stack(utt_emb_dict['embs']))
+    return subsegs_list, embeddings_list
+
+
+def cluster(embeddings):
+    # Fallback
+    if len(embeddings) <= 2:
+        return [0] * len(embeddings)
+
+    n_neighbors, min_dist = int(args.n_neighbors), float(args.min_dist)
+
+    umap_embeddings = umap.UMAP(n_components=min(32, len(embeddings) - 2),
+                                metric='cosine',
+                                n_neighbors=n_neighbors,
+                                min_dist=min_dist,
+                                random_state=2020,
+                                n_jobs=1).fit_transform(np.array(embeddings))
+
+    labels = hdbscan.HDBSCAN(core_dist_n_jobs=1,
+                             allow_single_cluster=True,
+                             min_cluster_size=4).fit_predict(umap_embeddings)
+
+    labels = pahc.PAHC(merge_cutoff=0.3,
+                       min_cluster_size=3,
+                       absorb_cutoff=0.0).fit_predict(labels, embeddings)
+    return labels
+
+
+if __name__ == '__main__':
+    args = get_args()
+
+    subsegs_list, embeddings_list = read_emb(args.scp)
+
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        with open(args.output, 'w') as fd:
+            for (subsegs, labels) in zip(subsegs_list,
+                                         executor.map(cluster,
+                                                      embeddings_list)):
+                [print(subseg,
+                       label,
+                       file=fd) for (subseg, label) in zip(subsegs, labels)]

From fdcf72aecab2b046cdd80974c348277474aaf0ce Mon Sep 17 00:00:00 2001
From: xx205
Date: Sun, 11 Aug 2024 19:50:28 +0000
Subject: [PATCH 02/12] resolve pylint issues and add missing modifications

---
 requirements.txt                  |  2 ++
 wespeaker/cli/speaker.py          | 21 +++-------------
 wespeaker/diar/make_system_sad.py | 40 +++++++++----------------------
 wespeaker/diar/pahc.py            |  8 +++----
 4 files changed, 20 insertions(+), 51 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 535dc650..455fb1e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,3 +23,5 @@ soundfile==0.10.3.post1
 pypeln==0.4.9
 silero-vad
 pre-commit==3.5.0
+hdbscan==0.8.37
+umap-learn==0.5.6
diff --git a/wespeaker/cli/speaker.py b/wespeaker/cli/speaker.py
index 22879e9b..b2806d07 100644
--- a/wespeaker/cli/speaker.py
+++ b/wespeaker/cli/speaker.py
@@ -29,7 +29,7 @@
 from wespeaker.cli.utils import get_args
 from wespeaker.models.speaker_model import get_speaker_model
 from wespeaker.utils.checkpoint import load_checkpoint
-from wespeaker.diar.spectral_clusterer import cluster
+from wespeaker.diar.umap_clusterer import cluster
 from wespeaker.diar.extract_emb import subsegment
 from wespeaker.diar.make_rttm import merge_segments
 from wespeaker.utils.utils import set_seed
@@ -55,9 +55,6 @@ def __init__(self, model_dir: str):
         self.wavform_norm = False
 
         # diarization params
-
self.diar_num_spks = None - self.diar_min_num_spks = 1 - self.diar_max_num_spks = 20 self.diar_min_duration = 0.255 self.diar_window_secs = 1.5 self.diar_period_secs = 0.75 @@ -83,18 +80,12 @@ def set_gpu(self, device_id: int): self.model = self.model.to(self.device) def set_diarization_params(self, - num_spks=None, - min_num_spks=1, - max_num_spks=20, min_duration: float = 0.255, window_secs: float = 1.5, period_secs: float = 0.75, frame_shift: int = 10, batch_size: int = 32, subseg_cmn: bool = True): - self.diar_num_spks = num_spks - self.diar_min_num_spks = min_num_spks - self.diar_max_num_spks = max_num_spks self.diar_min_duration = min_duration self.diar_window_secs = window_secs self.diar_period_secs = period_secs @@ -251,10 +242,7 @@ def diarize(self, audio_path: str, utt: str = "unk"): # 4. cluster subseg2label = [] - labels = cluster(embeddings, - num_spks=self.diar_num_spks, - min_num_spks=self.diar_min_num_spks, - max_num_spks=self.diar_max_num_spks) + labels = cluster(embeddings) for (_subseg, _label) in zip(subsegs, labels): # b, e = process_seg_id(_subseg, frame_shift=self.diar_frame_shift) # subseg2label.append([b, e, _label]) @@ -316,10 +304,7 @@ def main(): model.set_resample_rate(args.resample_rate) model.set_vad(args.vad) model.set_gpu(args.gpu) - model.set_diarization_params(num_spks=args.diar_num_spks, - min_num_spks=args.diar_min_num_spks, - max_num_spks=args.diar_max_num_spks, - min_duration=args.diar_min_duration, + model.set_diarization_params(min_duration=args.diar_min_duration, window_secs=args.diar_window_secs, period_secs=args.diar_period_secs, frame_shift=args.diar_frame_shift, diff --git a/wespeaker/diar/make_system_sad.py b/wespeaker/diar/make_system_sad.py index c4ab13ec..a4a98911 100644 --- a/wespeaker/diar/make_system_sad.py +++ b/wespeaker/diar/make_system_sad.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 Xu Xiang +# Copyright (c) 2022-2024 Xu Xiang # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -20,21 +20,17 @@ os.environ["VECLIB_MAXIMUM_THREADS"] = "1" os.environ["NUMEXPR_NUM_THREADS"] = "1" -import sys import functools import concurrent.futures import argparse -import importlib import torch +import silero_vad from wespeaker.utils.file_utils import read_scp def get_args(): parser = argparse.ArgumentParser(description='') - parser.add_argument('--repo-path', - required=True, - help='VAD model repo path') parser.add_argument('--scp', required=True, help='wav scp') parser.add_argument('--min-duration', required=True, @@ -45,28 +41,16 @@ def get_args(): return args -def silero_vad(utt_wav_pair, - repo_path, - min_duration, - sampling_rate=16000, - threshold=0.25): - - def module_from_file(module_name, file_path): - spec = importlib.util.spec_from_file_location(module_name, file_path) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - utils_vad = module_from_file("utils_vad", - os.path.join(repo_path, "utils_vad.py")) - model = utils_vad.init_jit_model( - os.path.join(repo_path, 'files/silero_vad.jit')) +def vad(utt_wav_pair, + min_duration, + sampling_rate=16000, + threshold=0.18): + model = silero_vad.load_silero_vad() utt, wav = utt_wav_pair - wav = utils_vad.read_audio(wav, sampling_rate=sampling_rate) - speech_timestamps = utils_vad.get_speech_timestamps( + wav = silero_vad.read_audio(wav, sampling_rate=sampling_rate) + speech_timestamps = silero_vad.get_speech_timestamps( wav, model, sampling_rate=sampling_rate, threshold=threshold) vad_result = "" @@ -83,13 +67,11 @@ def module_from_file(module_name, file_path): def main(): args = get_args() - vad = functools.partial(silero_vad, - repo_path=args.repo_path, - min_duration=args.min_duration) + run_vad = functools.partial(vad, min_duration=args.min_duration) utt_wav_pair_list = read_scp(args.scp) with concurrent.futures.ProcessPoolExecutor() as executor: - print(''.join(executor.map(vad, utt_wav_pair_list)), end='') + print(''.join(executor.map(run_vad, utt_wav_pair_list)), end='') if __name__ == '__main__': diff --git a/wespeaker/diar/pahc.py b/wespeaker/diar/pahc.py index 4b5c6460..19c4d972 100644 --- a/wespeaker/diar/pahc.py +++ b/wespeaker/diar/pahc.py @@ -44,7 +44,7 @@ def initialize(self, labels, embeddings): self.cost_map = {} self.heap = [] self.next_index = -1 - + self.build_label_map() self.build_cost_map() @@ -93,11 +93,11 @@ def relabel_cluster(self): for i in range(len(labels)): labels[i] = label_to_label[labels[i]] return labels - + def eliminate(self, i): del self.label_map[i] self.active_clusters.remove(i) - + def build_label_map(self): self.label_map = defaultdict(list) @@ -126,7 +126,7 @@ def build_cost_map(self): if i < self.num_labeled and j < self.num_labeled: self.cost_map[(i, j)] = -np.inf continue - + self.cost_map[(i, j)] = self.compute_cost(i_indexes, j_indexes) factor = len(i_indexes) * len(j_indexes) From 700dfe0290658616521334b01b73b16972cdf02a Mon Sep 17 00:00:00 2001 From: xx205 Date: Sun, 11 Aug 2024 20:05:04 +0000 Subject: [PATCH 03/12] eliminate trailing whitespace --- wespeaker/diar/pahc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wespeaker/diar/pahc.py b/wespeaker/diar/pahc.py index 19c4d972..3e4f2d94 100644 --- a/wespeaker/diar/pahc.py +++ b/wespeaker/diar/pahc.py @@ -108,7 +108,7 @@ def build_label_map(self): if -1 in self.label_map: self.num_labeled -= 1 - for i, j in zip(range(self.num_labeled, + for i, j in zip(range(self.num_labeled, self.num_labeled + 
len(self.label_map[-1])),
                             self.label_map[-1]):
                 self.label_map[i].append(j)
@@ -132,7 +132,7 @@ def build_cost_map(self):
                 factor = len(i_indexes) * len(j_indexes)
                 normalized_cost = self.cost_map[(i, j)] / factor
                 if normalized_cost >= self.merge_cutoff:
-                    heapq.heappush(self.heap, (-normalized_cost, (i, j))) 
+                    heapq.heappush(self.heap, (-normalized_cost, (i, j)))
 
     def compute_cost(self, i_indexes, j_indexes):
         i_embedding = sum([
             l2norm(self.embeddings[i_index]) for i_index in i_indexes])
@@ -155,7 +155,7 @@ def merge(self, i, j):
             factor = (len(i_indexes) + len(j_indexes)) * len(self.label_map[k])
             normalized_cost = cost / factor
             if normalized_cost >= self.merge_cutoff:
-                heapq.heappush(self.heap, (-normalized_cost, 
+                heapq.heappush(self.heap, (-normalized_cost,
                                            (k, self.next_index)))
 
         self.label_map[self.next_index] = i_indexes + j_indexes

From 77c340ddb826a9ce4d3aa73aa385292bf7828ed8 Mon Sep 17 00:00:00 2001
From: xx205
Date: Mon, 12 Aug 2024 02:39:46 +0000
Subject: [PATCH 04/12] deterministic clustering; update README.md

---
 examples/voxconverse/v3/README.md |  7 ++++---
 wespeaker/diar/umap_clusterer.py  | 13 ++++++++-----
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/examples/voxconverse/v3/README.md b/examples/voxconverse/v3/README.md
index 02f41fa1..c86e6e34 100644
--- a/examples/voxconverse/v3/README.md
+++ b/examples/voxconverse/v3/README.md
@@ -15,8 +15,8 @@
 
   | system | MISS | FA | SC | DER |
   |:---|:---:|:---:|:---:|:---:|
-  | This repo (with oracle SAD) | 2.3 | 0.0 | 2.1 | 4.4 |
-  | This repo (with system SAD) | 3.7 | 0.8 | 2.2 | 6.8 |
+  | This repo (with oracle SAD) | 2.3 | 0.0 | 1.3 | 3.6 |
+  | This repo (with system SAD) | 3.4 | 0.6 | 1.4 | 5.4 |
   | DIHARD 2019 baseline [^1] | 11.1 | 1.4 | 11.3 | 23.8 |
   | DIHARD 2019 baseline w/ SE [^1] | 9.3 | 1.3 | 9.7 | 20.2 |
   | (SyncNet ASD only) [^1] | 2.2 | 4.1 | 4.0 | 10.4 |
@@ -27,7 +27,8 @@
 
   | system | MISS | FA | SC | DER |
   |:---|:---:|:---:|:---:|:---:|
-  | This repo (with system SAD) | 4.0 | 2.4 | 3.4 | 9.8 |
+  | This repo (with oracle SAD) | 1.6 | 0.0 | 1.9 | 3.5 |
+  | This repo (with system SAD) | 3.8 | 1.7 | 1.8 | 7.4 |
 
 
 [^1]: Spot the conversation: speaker diarisation in the wild, https://arxiv.org/pdf/2007.01216.pdf
diff --git a/wespeaker/diar/umap_clusterer.py b/wespeaker/diar/umap_clusterer.py
index aae01368..26ae2a1e 100644
--- a/wespeaker/diar/umap_clusterer.py
+++ b/wespeaker/diar/umap_clusterer.py
@@ -50,7 +50,7 @@ def get_args():
                              "when estimating the manifold structure of "
                              "the data, losing fine detail structure for "
                              "the sake of getting a broader view of the data.")
-    parser.add_argument('--min_dist', required=False, default=0.1,
+    parser.add_argument('--min_dist', required=False, default=0.05,
                         help="The minimum distance between points in "
                              "the low dimensional representation.")
     args = parser.parse_args()
@@ -86,12 +86,13 @@ def cluster(embeddings):
                                 metric='cosine',
                                 n_neighbors=n_neighbors,
                                 min_dist=min_dist,
-                                random_state=2020,
+                                random_state=2023,
                                 n_jobs=1).fit_transform(np.array(embeddings))
 
-    labels = hdbscan.HDBSCAN(core_dist_n_jobs=1,
-                             allow_single_cluster=True,
-                             min_cluster_size=4).fit_predict(umap_embeddings)
+    labels = hdbscan.HDBSCAN(allow_single_cluster=True,
+                             min_cluster_size=4,
+                             approx_min_span_tree=False,
+                             core_dist_n_jobs=1).fit_predict(umap_embeddings)
 
     labels = pahc.PAHC(merge_cutoff=0.3,
                        min_cluster_size=3,
@@ -104,6 +105,8 @@ def cluster(embeddings):
 
     subsegs_list, embeddings_list = read_emb(args.scp)
 
+    os.makedirs(os.path.dirname(args.output), exist_ok=True)
+
     with concurrent.futures.ProcessPoolExecutor() as executor:
         with open(args.output, 'w') as
fd: for (subsegs, labels) in zip(subsegs_list, From 7636e32244d01b97bec28277ceb26861a1579461 Mon Sep 17 00:00:00 2001 From: xx205 Date: Mon, 12 Aug 2024 04:49:07 +0000 Subject: [PATCH 05/12] fix args usage in umap_clusterer.py --- wespeaker/diar/umap_clusterer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/wespeaker/diar/umap_clusterer.py b/wespeaker/diar/umap_clusterer.py index 26ae2a1e..35da6f8c 100644 --- a/wespeaker/diar/umap_clusterer.py +++ b/wespeaker/diar/umap_clusterer.py @@ -24,6 +24,7 @@ import argparse import concurrent.futures from collections import OrderedDict +import functools import numpy as np @@ -75,13 +76,11 @@ def read_emb(scp): return subsegs_list, embeddings_list -def cluster(embeddings): +def cluster(embeddings, n_neighbors=16, min_dist=0.05): # Fallback if len(embeddings) <= 2: return [0] * len(embeddings) - n_neighbors, min_dist = int(args.n_neighbors), float(args.min_dist) - umap_embeddings = umap.UMAP(n_components=min(32, len(embeddings) - 2), metric='cosine', n_neighbors=n_neighbors, @@ -107,10 +106,16 @@ def cluster(embeddings): os.makedirs(os.path.dirname(args.output), exist_ok=True) + n_neighbors, min_dist = int(args.n_neighbors), float(args.min_dist) + + run_cluster = functools.partial(cluster, + n_neighbors=n_neighbors, + min_dist=min_dist) + with concurrent.futures.ProcessPoolExecutor() as executor: with open(args.output, 'w') as fd: for (subsegs, labels) in zip(subsegs_list, - executor.map(cluster, + executor.map(run_cluster, embeddings_list)): [print(subseg, label, From 2731690ef7f1ffce3eb64578cfe5ce3d2ab622d2 Mon Sep 17 00:00:00 2001 From: xx205 Date: Mon, 12 Aug 2024 15:55:01 +0000 Subject: [PATCH 06/12] local import; remove unused diarization args; self.model.eval() when init --- wespeaker/cli/speaker.py | 11 ++++++----- wespeaker/cli/utils.py | 12 ------------ wespeaker/diar/umap_clusterer.py | 2 +- 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/wespeaker/cli/speaker.py b/wespeaker/cli/speaker.py index b2806d07..be064ac1 100644 --- a/wespeaker/cli/speaker.py +++ b/wespeaker/cli/speaker.py @@ -47,6 +47,7 @@ def __init__(self, model_dir: str): self.model = get_speaker_model( configs['model'])(**configs['model_args']) load_checkpoint(self.model, model_path) + self.model.eval() self.vad = load_silero_vad() self.table = {} self.resample_rate = 16000 @@ -118,10 +119,10 @@ def extract_embedding_feats(self, fbanks, batch_size, subseg_cmn): fbanks_array = torch.from_numpy(fbanks_array).to(self.device) for i in tqdm(range(0, fbanks_array.shape[0], batch_size)): batch_feats = fbanks_array[i:i + batch_size] - # _, batch_embs = self.model(batch_feats) - batch_embs = self.model(batch_feats) - batch_embs = batch_embs[-1] if isinstance(batch_embs, - tuple) else batch_embs + with torch.no_grad(): + batch_embs = self.model(batch_feats) + batch_embs = batch_embs[-1] if isinstance(batch_embs, + tuple) else batch_embs embeddings.append(batch_embs.detach().cpu().numpy()) embeddings = np.vstack(embeddings) return embeddings @@ -153,7 +154,7 @@ def extract_embedding(self, audio_path: str): cmn=True) feats = feats.unsqueeze(0) feats = feats.to(self.device) - self.model.eval() + with torch.no_grad(): outputs = self.model(feats) outputs = outputs[-1] if isinstance(outputs, tuple) else outputs diff --git a/wespeaker/cli/utils.py b/wespeaker/cli/utils.py index 8289114b..3d5125f2 100644 --- a/wespeaker/cli/utils.py +++ b/wespeaker/cli/utils.py @@ -75,18 +75,6 @@ def get_args(): help='output file to save speaker embedding ' 'or 
save diarization result') # diarization params - parser.add_argument('--diar_num_spks', - type=int, - default=None, - help='number of speakers') - parser.add_argument('--diar_min_num_spks', - type=int, - default=1, - help='minimum number of speakers') - parser.add_argument('--diar_max_num_spks', - type=int, - default=20, - help='maximum number of speakers') parser.add_argument('--diar_min_duration', type=float, default=0.255, diff --git a/wespeaker/diar/umap_clusterer.py b/wespeaker/diar/umap_clusterer.py index 35da6f8c..e7d673a5 100644 --- a/wespeaker/diar/umap_clusterer.py +++ b/wespeaker/diar/umap_clusterer.py @@ -31,7 +31,7 @@ import kaldiio import umap import hdbscan -import pahc +from . import pahc def get_args(): From 4ac134dca6c32180bea6372bb16cf881296c6dac Mon Sep 17 00:00:00 2001 From: xx205 Date: Mon, 12 Aug 2024 16:18:16 +0000 Subject: [PATCH 07/12] compact embedding clustering procedure into a single source file --- wespeaker/diar/pahc.py | 165 ------------------------------- wespeaker/diar/umap_clusterer.py | 155 ++++++++++++++++++++++++++++- 2 files changed, 150 insertions(+), 170 deletions(-) delete mode 100644 wespeaker/diar/pahc.py diff --git a/wespeaker/diar/pahc.py b/wespeaker/diar/pahc.py deleted file mode 100644 index 3e4f2d94..00000000 --- a/wespeaker/diar/pahc.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright (c) 2023 Xu Xiang -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import heapq -from collections import defaultdict - -import numpy as np - - -def l2norm(x, axis=0, keepdims=True): - return x / np.linalg.norm(x, axis=axis, keepdims=keepdims) - - -class PAHC: - def __init__(self, merge_cutoff=0.3, min_cluster_size=3, absorb_cutoff=0.0): - self.merge_cutoff = merge_cutoff - self.min_cluster_size = min_cluster_size - self.absorb_cutoff = absorb_cutoff - - def fit_predict(self, labels, embeddings): - self.initialize(labels, embeddings) - self.merge_cluster() - self.absorb_cluster() - labels = self.relabel_cluster() - return labels - - def initialize(self, labels, embeddings): - self.labels = labels - self.embeddings = embeddings - self.active_clusters = set([]) - self.label_map = {} - self.cost_map = {} - self.heap = [] - self.next_index = -1 - - self.build_label_map() - self.build_cost_map() - - def merge_cluster(self): - while self.heap: - _, (i, j) = heapq.heappop(self.heap) - if i in self.active_clusters and j in self.active_clusters: - self.merge(i, j) - - def absorb_cluster(self): - minor_clusters = set() - major_clusters = set() - for k, indexes in self.label_map.items(): - if len(indexes) < self.min_cluster_size: - minor_clusters.add(k) - else: - major_clusters.add(k) - - if len(major_clusters) > 0: - for i in minor_clusters: - max_cost = -np.inf - for j in major_clusters: - pair = (i, j) if i < j else (j, i) - i_indexes, j_indexes = self.label_map[i], self.label_map[j] - factor = len(i_indexes) * len(j_indexes) - normalized_cost = self.cost_map[pair] / factor - if normalized_cost > max_cost: - max_cost = normalized_cost - closest_cluster = j - if max_cost >= self.absorb_cutoff: - self.label_map[closest_cluster].extend(self.label_map[i]) - self.eliminate(i) - - def relabel_cluster(self): - labels = [-1] * len(self.labels) - - for label, indexes in self.label_map.items(): - for index in indexes: - labels[index] = label - i = 0 - label_to_label = {} - for label in labels: - if label not in label_to_label: - label_to_label[label] = i - i += 1 - for i in range(len(labels)): - labels[i] = label_to_label[labels[i]] - return labels - - def eliminate(self, i): - del self.label_map[i] - self.active_clusters.remove(i) - - def build_label_map(self): - self.label_map = defaultdict(list) - - for i, label in enumerate(self.labels): - self.label_map[label].append(i) - - self.num_labeled = len(self.label_map) - - if -1 in self.label_map: - self.num_labeled -= 1 - for i, j in zip(range(self.num_labeled, - self.num_labeled + len(self.label_map[-1])), - self.label_map[-1]): - self.label_map[i].append(j) - del self.label_map[-1] - - def build_cost_map(self): - N = len(self.label_map) - self.active_clusters = set(range(N)) - self.next_index = N - - for i in range(N): - for j in range(i + 1, N): - i_indexes, j_indexes = self.label_map[i], self.label_map[j] - - if i < self.num_labeled and j < self.num_labeled: - self.cost_map[(i, j)] = -np.inf - continue - - self.cost_map[(i, j)] = self.compute_cost(i_indexes, j_indexes) - - factor = len(i_indexes) * len(j_indexes) - normalized_cost = self.cost_map[(i, j)] / factor - if normalized_cost >= self.merge_cutoff: - heapq.heappush(self.heap, (-normalized_cost, (i, j))) - - def compute_cost(self, i_indexes, j_indexes): - i_embedding = sum([ - l2norm(self.embeddings[i_index]) for i_index in i_indexes]) - j_embedding = sum([ - l2norm(self.embeddings[j_index]) for j_index in j_indexes]) - return np.dot(i_embedding, j_embedding) - - def merge(self, i, j): - i_indexes, j_indexes = self.label_map[i], self.label_map[j] - - for k, _ 
in self.label_map.items(): - if k == i or k == j: - continue - pair1 = (k, i) if k < i else (i, k) - pair2 = (k, j) if k < j else (j, k) - cost = self.cost_map[pair1] + self.cost_map[pair2] - self.cost_map[(k, self.next_index)] = cost - - factor = (len(i_indexes) + len(j_indexes)) * len(self.label_map[k]) - normalized_cost = cost / factor - if normalized_cost >= self.merge_cutoff: - heapq.heappush(self.heap, (-normalized_cost, - (k, self.next_index))) - - self.label_map[self.next_index] = i_indexes + j_indexes - self.active_clusters.add(self.next_index) - self.eliminate(i) - self.eliminate(j) - self.next_index += 1 diff --git a/wespeaker/diar/umap_clusterer.py b/wespeaker/diar/umap_clusterer.py index e7d673a5..23b38ead 100644 --- a/wespeaker/diar/umap_clusterer.py +++ b/wespeaker/diar/umap_clusterer.py @@ -23,15 +23,160 @@ import argparse import concurrent.futures -from collections import OrderedDict +from collections import OrderedDict, defaultdict import functools +import heapq import numpy as np import kaldiio import umap import hdbscan -from . import pahc + + +class PAHC: + def __init__(self, merge_cutoff=0.3, min_cluster_size=3, absorb_cutoff=0.0): + self.merge_cutoff = merge_cutoff + self.min_cluster_size = min_cluster_size + self.absorb_cutoff = absorb_cutoff + + def fit_predict(self, labels, embeddings): + self.initialize(labels, embeddings) + self.merge_cluster() + self.absorb_cluster() + labels = self.relabel_cluster() + return labels + + def initialize(self, labels, embeddings): + self.labels = labels + self.embeddings = embeddings + self.active_clusters = set([]) + self.label_map = {} + self.cost_map = {} + self.heap = [] + self.next_index = -1 + + self.build_label_map() + self.build_cost_map() + + def merge_cluster(self): + while self.heap: + _, (i, j) = heapq.heappop(self.heap) + if i in self.active_clusters and j in self.active_clusters: + self.merge(i, j) + + def absorb_cluster(self): + minor_clusters = set() + major_clusters = set() + for k, indexes in self.label_map.items(): + if len(indexes) < self.min_cluster_size: + minor_clusters.add(k) + else: + major_clusters.add(k) + + if len(major_clusters) > 0: + for i in minor_clusters: + max_cost = -np.inf + for j in major_clusters: + pair = (i, j) if i < j else (j, i) + i_indexes, j_indexes = self.label_map[i], self.label_map[j] + factor = len(i_indexes) * len(j_indexes) + normalized_cost = self.cost_map[pair] / factor + if normalized_cost > max_cost: + max_cost = normalized_cost + closest_cluster = j + if max_cost >= self.absorb_cutoff: + self.label_map[closest_cluster].extend(self.label_map[i]) + self.eliminate(i) + + def relabel_cluster(self): + labels = [-1] * len(self.labels) + + for label, indexes in self.label_map.items(): + for index in indexes: + labels[index] = label + i = 0 + label_to_label = {} + for label in labels: + if label not in label_to_label: + label_to_label[label] = i + i += 1 + for i in range(len(labels)): + labels[i] = label_to_label[labels[i]] + return labels + + def eliminate(self, i): + del self.label_map[i] + self.active_clusters.remove(i) + + def build_label_map(self): + self.label_map = defaultdict(list) + + for i, label in enumerate(self.labels): + self.label_map[label].append(i) + + self.num_labeled = len(self.label_map) + + if -1 in self.label_map: + self.num_labeled -= 1 + for i, j in zip(range(self.num_labeled, + self.num_labeled + len(self.label_map[-1])), + self.label_map[-1]): + self.label_map[i].append(j) + del self.label_map[-1] + + def build_cost_map(self): + N = len(self.label_map) + 
self.active_clusters = set(range(N)) + self.next_index = N + + for i in range(N): + for j in range(i + 1, N): + i_indexes, j_indexes = self.label_map[i], self.label_map[j] + + if i < self.num_labeled and j < self.num_labeled: + self.cost_map[(i, j)] = -np.inf + continue + + self.cost_map[(i, j)] = self.compute_cost(i_indexes, j_indexes) + + factor = len(i_indexes) * len(j_indexes) + normalized_cost = self.cost_map[(i, j)] / factor + if normalized_cost >= self.merge_cutoff: + heapq.heappush(self.heap, (-normalized_cost, (i, j))) + + def compute_cost(self, i_indexes, j_indexes): + i_embedding = sum([ + self.l2norm(self.embeddings[i_index]) for i_index in i_indexes]) + j_embedding = sum([ + self.l2norm(self.embeddings[j_index]) for j_index in j_indexes]) + return np.dot(i_embedding, j_embedding) + + def merge(self, i, j): + i_indexes, j_indexes = self.label_map[i], self.label_map[j] + + for k, _ in self.label_map.items(): + if k == i or k == j: + continue + pair1 = (k, i) if k < i else (i, k) + pair2 = (k, j) if k < j else (j, k) + cost = self.cost_map[pair1] + self.cost_map[pair2] + self.cost_map[(k, self.next_index)] = cost + + factor = (len(i_indexes) + len(j_indexes)) * len(self.label_map[k]) + normalized_cost = cost / factor + if normalized_cost >= self.merge_cutoff: + heapq.heappush(self.heap, (-normalized_cost, + (k, self.next_index))) + + self.label_map[self.next_index] = i_indexes + j_indexes + self.active_clusters.add(self.next_index) + self.eliminate(i) + self.eliminate(j) + self.next_index += 1 + + def l2norm(self, x, axis=0, keepdims=True): + return x / np.linalg.norm(x, axis=axis, keepdims=keepdims) def get_args(): @@ -93,9 +238,9 @@ def cluster(embeddings, n_neighbors=16, min_dist=0.05): approx_min_span_tree=False, core_dist_n_jobs=1).fit_predict(umap_embeddings) - labels = pahc.PAHC(merge_cutoff=0.3, - min_cluster_size=3, - absorb_cutoff=0.0).fit_predict(labels, embeddings) + labels = PAHC(merge_cutoff=0.3, + min_cluster_size=3, + absorb_cutoff=0.0).fit_predict(labels, embeddings) return labels From f894bb265fcecd1fa09af1891819ce3707e013c1 Mon Sep 17 00:00:00 2001 From: xx205 Date: Mon, 19 Aug 2024 16:20:24 +0000 Subject: [PATCH 08/12] link to local and path.sh; update requirements.txt and extract_emb.py --- examples/voxconverse/v3/local | 1 + examples/voxconverse/v3/local/extract_emb.sh | 63 -------------------- examples/voxconverse/v3/local/make_fbank.sh | 52 ---------------- examples/voxconverse/v3/path.sh | 6 +- requirements.txt | 1 + wespeaker/diar/extract_emb.py | 1 + 6 files changed, 4 insertions(+), 120 deletions(-) create mode 120000 examples/voxconverse/v3/local delete mode 100755 examples/voxconverse/v3/local/extract_emb.sh delete mode 100755 examples/voxconverse/v3/local/make_fbank.sh mode change 100644 => 120000 examples/voxconverse/v3/path.sh diff --git a/examples/voxconverse/v3/local b/examples/voxconverse/v3/local new file mode 120000 index 00000000..8b1d5f97 --- /dev/null +++ b/examples/voxconverse/v3/local @@ -0,0 +1 @@ +../v2/local \ No newline at end of file diff --git a/examples/voxconverse/v3/local/extract_emb.sh b/examples/voxconverse/v3/local/extract_emb.sh deleted file mode 100755 index b12a1c04..00000000 --- a/examples/voxconverse/v3/local/extract_emb.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash -# Copyright (c) 2022 Zhengyang Chen (chenzhengyang117@gmail.com) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-. ./path.sh || exit 1
-
-scp=''
-pretrained_model=''
-device=cuda
-store_dir=''
-subseg_cmn=true
-nj=1
-
-batch_size=96
-frame_shift=10
-window_secs=1.5
-period_secs=0.75
-
-. tools/parse_options.sh
-
-split_dir=$store_dir/split_scp
-log_dir=$store_dir/log
-mkdir -p $split_dir
-mkdir -p $log_dir
-
-# split the scp file into sub-files so that multiple processes can extract embeddings in parallel
-file_len=`wc -l $scp | awk '{print $1}'`
-subfile_len=$[$file_len / $nj + 1]
-prefix='split'
-split -l $subfile_len -d -a 3 $scp ${split_dir}/${prefix}_scp_
-
-for suffix in `seq 0 $[$nj-1]`;do
-    suffix=`printf '%03d' $suffix`
-    scp_subfile=${split_dir}/${prefix}_scp_${suffix}
-    write_ark=$store_dir/emb_${suffix}.ark
-    python3 wespeaker/diar/extract_emb.py \
-            --scp ${scp_subfile} \
-            --ark-path ${write_ark} \
-            --source ${pretrained_model} \
-            --device ${device} \
-            --batch-size ${batch_size} \
-            --frame-shift ${frame_shift} \
-            --window-secs ${window_secs} \
-            --period-secs ${period_secs} \
-            --subseg-cmn ${subseg_cmn} \
-            > ${log_dir}/${prefix}.${suffix}.log 2>&1 &
-done
-
-wait
-
-cat $store_dir/emb_*.scp > $store_dir/emb.scp
-echo "Finished extracting embeddings."
diff --git a/examples/voxconverse/v3/local/make_fbank.sh b/examples/voxconverse/v3/local/make_fbank.sh
deleted file mode 100755
index 3224c151..00000000
--- a/examples/voxconverse/v3/local/make_fbank.sh
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/bin/bash
-# Copyright (c) 2022 Zhengyang Chen (chenzhengyang117@gmail.com)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-. ./path.sh || exit 1
-
-scp=''
-segments=''
-store_dir=''
-subseg_cmn=true
-nj=1
-
-. tools/parse_options.sh
-
-split_dir=$store_dir/split_scp
-log_dir=$store_dir/log
-mkdir -p $split_dir
-mkdir -p $log_dir
-
-# split the scp file into sub-files so that multiple processes can extract Fbank features in parallel
-file_len=`wc -l $scp | awk '{print $1}'`
-subfile_len=$[$file_len / $nj + 1]
-prefix='split'
-split -l $subfile_len -d -a 3 $scp ${split_dir}/${prefix}_scp_
-
-for suffix in `seq 0 $[$nj-1]`;do
-    suffix=`printf '%03d' $suffix`
-    scp_subfile=${split_dir}/${prefix}_scp_${suffix}
-    write_ark=$store_dir/fbank_${suffix}.ark
-    python3 wespeaker/diar/make_fbank.py \
-            --scp ${scp_subfile} \
-            --segments ${segments} \
-            --ark-path ${write_ark} \
-            --subseg-cmn ${subseg_cmn} \
-            > ${log_dir}/${prefix}.${suffix}.log 2>&1 &
-done
-
-wait
-
-cat $store_dir/fbank_*.scp > $store_dir/fbank.scp
-echo "Finished making Fbank features."
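The helper scripts deleted above are now shared with the v2 recipe through the `local` symlink added earlier in this patch. Their shard-and-gather pattern (split the input scp into `nj` chunks, run one worker per chunk, then concatenate the per-chunk outputs) can be sketched in Python as follows; the names and paths are illustrative only:

```python
# Sketch of the shard-and-gather pattern used by local/extract_emb.sh and
# local/make_fbank.sh: split the scp into nj chunks, process the chunks in
# parallel, then merge the per-chunk outputs (like `cat emb_*.scp > emb.scp`).
import concurrent.futures


def process_chunk(lines):
    # stand-in for one "python3 wespeaker/diar/extract_emb.py --scp ..." job
    return [line.split()[0] for line in lines]


def shard(lines, nj):
    size = len(lines) // nj + 1  # mirrors subfile_len=$[$file_len / $nj + 1]
    return [lines[i:i + size] for i in range(0, len(lines), size)]


if __name__ == '__main__':
    scp = [f'utt{i} /path/to/utt{i}.wav' for i in range(10)]
    with concurrent.futures.ProcessPoolExecutor() as executor:
        chunks = executor.map(process_chunk, shard(scp, nj=3))
    merged = [utt for chunk in chunks for utt in chunk]
    print(len(merged))  # 10
```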
diff --git a/examples/voxconverse/v3/path.sh b/examples/voxconverse/v3/path.sh deleted file mode 100644 index b90a5154..00000000 --- a/examples/voxconverse/v3/path.sh +++ /dev/null @@ -1,5 +0,0 @@ -export PATH=$PWD:$PATH - -# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C -export PYTHONIOENCODING=UTF-8 -export PYTHONPATH=../../../:$PYTHONPATH diff --git a/examples/voxconverse/v3/path.sh b/examples/voxconverse/v3/path.sh new file mode 120000 index 00000000..b6a713c8 --- /dev/null +++ b/examples/voxconverse/v3/path.sh @@ -0,0 +1 @@ +../v2/path.sh \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 455fb1e1..e765ff2e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,5 +23,6 @@ soundfile==0.10.3.post1 pypeln==0.4.9 silero-vad pre-commit==3.5.0 +s3prl hdbscan==0.8.37 umap-learn==0.5.6 diff --git a/wespeaker/diar/extract_emb.py b/wespeaker/diar/extract_emb.py index 54f98880..7170e7c7 100644 --- a/wespeaker/diar/extract_emb.py +++ b/wespeaker/diar/extract_emb.py @@ -37,6 +37,7 @@ def init_session(source, device): opts = ort.SessionOptions() opts.inter_op_num_threads = 1 opts.intra_op_num_threads = 1 + opts.log_severity_level=0 session = ort.InferenceSession(source, sess_options=opts, providers=providers) From 69d213490e433bc01c6ead56fdc14a48aa34fb42 Mon Sep 17 00:00:00 2001 From: xx205 Date: Tue, 20 Aug 2024 00:43:30 +0800 Subject: [PATCH 09/12] fix lint error: extract_emb.py --- wespeaker/diar/extract_emb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wespeaker/diar/extract_emb.py b/wespeaker/diar/extract_emb.py index 7170e7c7..da16fc85 100644 --- a/wespeaker/diar/extract_emb.py +++ b/wespeaker/diar/extract_emb.py @@ -37,7 +37,7 @@ def init_session(source, device): opts = ort.SessionOptions() opts.inter_op_num_threads = 1 opts.intra_op_num_threads = 1 - opts.log_severity_level=0 + opts.log_severity_level = 0 session = ort.InferenceSession(source, sess_options=opts, providers=providers) From 03c0e48f889b849605b1a54c322db0d975b01562 Mon Sep 17 00:00:00 2001 From: xx205 Date: Tue, 20 Aug 2024 12:42:45 +0800 Subject: [PATCH 10/12] Update README.md Update News section in README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index de8659d3..9474a7ed 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ pre-commit install # for clean and tidy code ``` ## 🔥 News +* 2024.08.20: Update diarization recipe of VoxConverse dataset to v3, see [#347](https://github.com/wenet-e2e/wespeaker/pull/347). * 2024.08.18: Support using ssl pre-trained models as the frontend. The [WavLM recipe](https://github.com/wenet-e2e/wespeaker/blob/master/examples/voxceleb/v2/run_wavlm.sh) is also provided, see [#344](https://github.com/wenet-e2e/wespeaker/pull/344). * 2024.05.15: Add support for [quality-aware score calibration](https://arxiv.org/pdf/2211.00815), see [#320](https://github.com/wenet-e2e/wespeaker/pull/320). * 2024.04.25: Add support for the gemini-dfresnet model, see [#291](https://github.com/wenet-e2e/wespeaker/pull/291). 
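One detail of the PAHC clustering consolidated into `wespeaker/diar/umap_clusterer.py` in patch 07 that is easy to miss: a pair cost is the dot product of summed L2-normalized embeddings, so dividing by |I|·|J| makes the normalized cost exactly the mean pairwise cosine similarity between clusters I and J, which lets `merge_cutoff=0.3` and `absorb_cutoff=0.0` be read directly as average-cosine-similarity thresholds. A small self-contained check of that identity:

```python
# Check that PAHC's normalized cost, dot(sum of unit-norm embeddings of I,
# sum of unit-norm embeddings of J) / (|I| * |J|), equals the mean pairwise
# cosine similarity between the two clusters.
import numpy as np

rng = np.random.default_rng(0)
emb_i = rng.normal(size=(3, 8))  # cluster I: 3 embeddings of dim 8
emb_j = rng.normal(size=(5, 8))  # cluster J: 5 embeddings of dim 8


def l2norm(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


cost = np.dot(l2norm(emb_i).sum(axis=0), l2norm(emb_j).sum(axis=0))
normalized_cost = cost / (len(emb_i) * len(emb_j))
mean_pairwise_cos = np.mean(l2norm(emb_i) @ l2norm(emb_j).T)
assert np.isclose(normalized_cost, mean_pairwise_cos)
```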
From a33c1cee72edddc65bb27dc4938ecf6853489040 Mon Sep 17 00:00:00 2001 From: xx205 Date: Tue, 20 Aug 2024 12:52:49 +0800 Subject: [PATCH 11/12] Update voxconverse/v3/README.md Update clustering method --- examples/voxconverse/v3/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/voxconverse/v3/README.md b/examples/voxconverse/v3/README.md index c86e6e34..5b333714 100644 --- a/examples/voxconverse/v3/README.md +++ b/examples/voxconverse/v3/README.md @@ -6,7 +6,7 @@ * Refer to [voxceleb sv recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v2) * [pretrained model path](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx) * Speaker activity detection model: oracle SAD (from ground truth annotation) or system SAD (VAD model pretrained by silero, https://github.com/snakers4/silero-vad) -* Clustering method: spectral clustering +* Clustering method: umap dimensionality reduction + hdbscan clustering * Metric: DER = MISS + FALSE ALARM + SPEAKER CONFUSION (%) ## Results From 78e52f8ebcb2382dfefcafde09ad5c7eed9579b9 Mon Sep 17 00:00:00 2001 From: Zhengyang Chen Date: Tue, 20 Aug 2024 12:57:55 +0800 Subject: [PATCH 12/12] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9474a7ed..afb05573 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ pre-commit install # for clean and tidy code ``` ## 🔥 News -* 2024.08.20: Update diarization recipe of VoxConverse dataset to v3, see [#347](https://github.com/wenet-e2e/wespeaker/pull/347). +* 2024.08.20: Update diarization recipe for VoxConverse dataset by leveraging umap dimensionality reduction and hdbscan clustering, see [#347](https://github.com/wenet-e2e/wespeaker/pull/347). * 2024.08.18: Support using ssl pre-trained models as the frontend. The [WavLM recipe](https://github.com/wenet-e2e/wespeaker/blob/master/examples/voxceleb/v2/run_wavlm.sh) is also provided, see [#344](https://github.com/wenet-e2e/wespeaker/pull/344). * 2024.05.15: Add support for [quality-aware score calibration](https://arxiv.org/pdf/2211.00815), see [#320](https://github.com/wenet-e2e/wespeaker/pull/320). * 2024.04.25: Add support for the gemini-dfresnet model, see [#291](https://github.com/wenet-e2e/wespeaker/pull/291).
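Taken together, the series replaces spectral clustering with UMAP dimensionality reduction followed by HDBSCAN density clustering and PAHC refinement. The sketch below exercises the UMAP and HDBSCAN stages on synthetic embeddings with the same parameters as `wespeaker/diar/umap_clusterer.py`, assuming the `umap-learn==0.5.6` and `hdbscan==0.8.37` versions pinned in `requirements.txt`; the synthetic data and the expected label set are illustrative:

```python
# Smoke test for the UMAP -> HDBSCAN chain, mirroring the parameters in
# wespeaker/diar/umap_clusterer.py on two synthetic "speakers".
import numpy as np
import umap
import hdbscan

rng = np.random.default_rng(7)
centers = rng.normal(size=(2, 256))  # one mean direction per speaker
embeddings = np.vstack([c + 0.05 * rng.normal(size=(40, 256))
                        for c in centers])  # 40 sub-segments per speaker

reduced = umap.UMAP(n_components=min(32, len(embeddings) - 2),
                    metric='cosine', n_neighbors=16, min_dist=0.05,
                    random_state=2023, n_jobs=1).fit_transform(embeddings)

labels = hdbscan.HDBSCAN(allow_single_cluster=True,
                         min_cluster_size=4,
                         approx_min_span_tree=False,
                         core_dist_n_jobs=1).fit_predict(reduced)

# Typically prints [0, 1]; any -1 noise points would be re-attached
# by the PAHC post-processing step in the recipe.
print(sorted(set(labels)))
```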