[ASR] Add optimization util for linear sum assignment algorithm #6349
Conversation
minor review. Will do thorough review tomorrow.
Very neat improvement, need to understand better from my end.
@@ -552,12 +550,13 @@ def eigDecompose(
        device = torch.cuda.current_device()
        laplacian = laplacian.float().to(device)
    else:
        laplacian = laplacian.float().to(torch.device('cpu'))
    laplacian = laplacian.float()
why same operation twice?
removed
    laplacian = laplacian.float()
    lambdas, diffusion_map = eigh(laplacian)
    return lambdas, diffusion_map


-def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')) -> torch.Tensor:
+def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device) -> torch.Tensor:
why both cuda and device? Isn't only one sufficient?
This was added long back because there are users setting cuda=True but device=cpu.
This adds some flexibility to avoid errors in such cases.
If we need to remove this, it requires a separate PR since this involves the whole diarization pipeline.
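(For illustration, a minimal sketch of how a cuda flag and a device argument can be reconciled; the helper name is hypothetical, not code from this PR:)

```python
import torch

def resolve_device(cuda: bool, device: torch.device) -> torch.device:
    # Hypothetical helper: only use CUDA when the flag, the device argument,
    # and the available hardware all agree; otherwise fall back to CPU.
    if cuda and device.type == 'cuda' and torch.cuda.is_available():
        return device
    # Covers the case mentioned above: cuda=True but device=cpu.
    return torch.device('cpu')

print(resolve_device(cuda=True, device=torch.device('cpu')))  # cpu
```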
    laplacian = laplacian.float().to(torch.device('cpu'))
    laplacian = laplacian.float()
same here. laplacian.float() twice
removed
stacked = np.hstack((enc_P, enc_Q))
cost = -1 * linear_kernel(stacked.T)[spk_count:, :spk_count]
row_ind, col_ind = linear_sum_assignment(cost)
PandQ_list: List[int] = [int(x.item()) for x in PandQ]
minor: mentioning the dtype in a variable name should be avoided
Makes sense, since types are strictly annotated for jit-scripted functions.
Fixed.
        marked (Tensor): 2D matrix containing the marked zeros.
    """

    def __init__(self, cost_matrix):
minor: mention the dtype of cost_matrix here. Isn't it necessary for jit scripting?
If there is no type annotation, the jit compiler treats it as torch.Tensor.
So, in general, a type annotation is needed only when an argument is not a torch.Tensor.
Added type annotations.
Got it
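(A minimal sketch of the TorchScript rule described above, not code from this PR: un-annotated parameters are assumed to be torch.Tensor, so only non-Tensor parameters need explicit annotations.)

```python
import torch

@torch.jit.script
def scale(x, factor: float):
    # 'x' carries no annotation, so the jit compiler infers torch.Tensor;
    # 'factor' is not a Tensor, so it must be annotated explicitly.
    return x * factor

print(scale(torch.ones(2), 3.0))  # tensor([3., 3.])
```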
if cost_matrix.shape[1] < cost_matrix.shape[0]:
    cost_matrix = cost_matrix.T
    transposed = True
else:
    transposed = False
Why the extra transposed variable? Use the same col < row condition below?
This follows the original implementation in scipy. If we don't use the transposed variable, we need to create another variable to indicate that, e.g. foo = cost_matrix.shape[1] < cost_matrix.shape[0].
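(A minimal sketch of the pattern under discussion, with an illustrative helper name; the condition can be assigned to the flag directly, which is the variable the reply refers to:)

```python
import torch

def normalize_cost_matrix(cost_matrix: torch.Tensor):
    # scipy-style convention: solve on a matrix with rows <= cols.
    # The flag remembers whether the input was flipped, so the resulting
    # row/column indices can be swapped back for the caller afterwards.
    transposed = cost_matrix.shape[1] < cost_matrix.shape[0]
    if transposed:
        cost_matrix = cost_matrix.T
    return cost_matrix, transposed

mat, transposed = normalize_cost_matrix(torch.zeros(3, 2))
print(mat.shape, transposed)  # torch.Size([2, 3]) True
```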
# Copyright (c) 2008 Brian M. Clapper <bmc@clapper.org>, Gael Varoquaux
# Author: Brian M. Clapper, Gael Varoquaux
# License: 3-clause BSD
Do we have only one optimization algorithm so far? Thinking about whether we should move other funcs to this file as well.
I think we can add other algorithms below this (I mentioned "Linear Sum Assignment solver").
The copyright at the beginning of the code is the convention in most projects, so I followed it.
nemo/collections/asr/metrics/der.py (outdated)
for label in ref_labels:
    start, end, speaker = label.split()
    start, end = float(start), float(end)
    # If the current [start, end] interval is latching the last prediction time
latching -> matching
Changed the expression (Checked by Elena)
@@ -31,67 +31,67 @@
# https://arxiv.org/pdf/2003.02405.pdf and the implementation from
# https://github.com/tango4j/Auto-Tuning-Spectral-Clustering.

-from typing import List, Tuple
+from typing import List, Set, Tuple

Check notice (Code scanning / CodeQL): Unused import
LGTM
…IA#6349)
* [ASR] Add optimization utils for cpWER, diarization training, online diarization
* Fixed GPU/CPU issues for clustering
* Fixed unreachable state
* resolved jit script compile error for lsa algorithm
* Fixed errors and bugs, checked tests
* Fixed docstrings
* Update changes on test files
* Refactored functions
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Adding docstrings for the functions in der.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fixed wrong docstrings in der.py
* Fixed a wrong docstring
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Changed np.array input to Tensor for LSA solver in der.py
* Added code-QL issues and unit-tests for der.py functions
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Removed print line in der.py
* Fixed code QL redundant comparison
* Fixed code QL issue
* Added License for the reference code
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Added full license text of the original code
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Reflected comments
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Reflected review comments

Signed-off-by: Taejin Park <tango4j@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
What does this PR do?
An LSA problem solver is needed for the following tasks in NeMo:
(1) Permutation Invariant Loss (PIL) for diarization model training
(2) Label permutation matching for online speaker diarization
(3) Concatenated minimum-permutation Word Error Rate (cp-WER) calculation
What is the LSA solver algorithm? See the Google OR-tools LSA Solver documentation.
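In short: given an n×m cost matrix, the solver picks one entry per row (and at most one per column) so that the total cost is minimized. A minimal worked example using the scipy reference implementation:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows can be thought of as workers, columns as tasks; entries are costs.
cost = np.array([[4, 1, 3],
                 [2, 0, 5],
                 [3, 2, 2]])
row_ind, col_ind = linear_sum_assignment(cost)
print(row_ind, col_ind)              # [0 1 2] [1 0 2]
print(cost[row_ind, col_ind].sum())  # 1 + 2 + 2 = 5
```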
The NeMo linear_sum_assignment function is compared with scipy.optimize.linear_sum_assignment: in the unit-test for the NeMo LSA solver, the result is compared with the scipy version of linear_sum_assignment.
Removing the @torch.jit.script decorator in speaker_utils.py since it creates type errors when the code is not used for production purposes. Instead, all classes and functions that require torch.jit.script are tested in test_diar_utils.py. Take a look at these tests, which check jit_script = [True/False] and cuda = [True/False] (4 combinations in total); a sketch of this pattern follows below.
Also refactored some of the functions for online diarization in online_clustering.py.
Added a couple of functions in der.py for online diarization DER calculation.
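A hedged sketch of what such a comparison test can look like; the NeMo import path is taken from the changelog below, and the test name and return convention are assumptions, not the exact code in tests/collections/asr/test_diar_utils.py:

```python
import pytest
import torch
from scipy.optimize import linear_sum_assignment as scipy_lsa

# Import path assumed from the changelog in this PR description.
from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment

@pytest.mark.parametrize("jit_script", [True, False])
@pytest.mark.parametrize("cuda", [True, False])
def test_lsa_matches_scipy(jit_script: bool, cuda: bool):
    if cuda and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    solver = torch.jit.script(linear_sum_assignment) if jit_script else linear_sum_assignment
    cost = torch.randn(8, 8)
    if cuda:
        cost = cost.cuda()
    row_ind, col_ind = solver(cost)
    # Random continuous costs make the optimum unique almost surely,
    # so the column assignment should match scipy's exactly.
    _, ref_col = scipy_lsa(cost.cpu().numpy())
    assert (col_ind.cpu().numpy() == ref_col).all()
```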
.Collection: [ASR]
Changelog
nemo/collections/asr/metrics/der.py:
- Replaced the scipy LSA solver with the NeMo LSA solver in the calculate_session_cpWER function.
- Added two functions for online diarization evaluations: get_partial_ref_labels and get_online_DER_stats.
nemo/collections/asr/models/online_diarizer.py:
- Made the _perform_online_clustering function simpler by moving get_reduced_mat and match_labels into the online clustering function.
nemo/collections/asr/parts/utils/offline_clustering.py:
- Added laplacian = laplacian.float().to(torch.device('cpu')) to avoid the jit-scripted module using the GPU even when CPU is specified, or vice versa. This behavior is always tested/checked in test_diar_utils.py.
nemo/collections/asr/parts/utils/online_clustering.py:
- Replaced the scipy LSA solver with the NeMo LSA solver in the get_lsa_speaker_mapping function.
- Modified the docstrings of update_speaker_history_buffer to make the example easier.
nemo/collections/asr/parts/utils/optimization_utils.py:
- Added a fully torch-jit-scriptable linear sum assignment problem solver class and function.
nemo/collections/asr/parts/utils/speaker_utils.py:
- Removed @torch.jit.script decorators since they create unnecessary warning messages and type-related errors when used without scripting.
tests/collections/asr/test_diar_metrics.py:
- Added unit-tests for the newly added functions get_partial_ref_labels and get_online_DER_stats.
tests/collections/asr/test_diar_utils.py:
- Added tests for offline clustering and online clustering for many different cases, including [jit-script=True, cuda=True], [jit-script=True, cuda=False], [jit-script=False, cuda=True], and [jit-script=False, cuda=False], using the torch-jit-scripted NeMo linear_sum_assignment function.
Usage
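(The template's Usage section was left empty; below is a minimal usage sketch, assuming the module path from the changelog above and a scipy-like return convention:)

```python
import torch
# Module path taken from the changelog above; treat it as an assumption.
from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment

cost = torch.tensor([[4., 1., 3.],
                     [2., 0., 5.],
                     [3., 2., 2.]])
row_ind, col_ind = linear_sum_assignment(cost)
print(cost[row_ind, col_ind].sum())  # expected: tensor(5.)
```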
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.