Skip to content

Commit

Permalink
Adds documentation for some of the classes and methods under function…
Browse files Browse the repository at this point in the history
…al.py and mapping.py
  • Loading branch information
zaouk committed Apr 5, 2022
1 parent 9e51718 commit 7cf7234
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/diart/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,19 @@ def __init__(
metric: Optional[str] = "cosine",
max_speakers: int = 20
):
"""Initializes an object for constrained incremental online clustering of speakers
Args:
tau_active (float): Threshold for detecting active speakers. This threshold is applied on the maximum value of each output
activation of the local segmentation model.
rho_update (float): Threshold for considering the extracted embedding when updating the centroid of the local speaker.
The centroid to which a local speaker is mapped is only updated if the duration of speech of a
local speaker is greater than this threshold.
delta_new (float): Threshold on the distance between a speaker embedding and a centroid. If the distance between a local speaker and all
centroids is larger than delta_new, then a new centroid is created for the current speaker.
metric (Optional[str], optional): The distance metric to use. Defaults to "cosine".
max_speakers (int, optional): Maximum number of global speakers to track through a conversation. Defaults to 20.
"""
self.tau_active = tau_active
self.rho_update = rho_update
self.delta_new = delta_new
Expand Down Expand Up @@ -323,17 +336,36 @@ def get_next_center_position(self) -> Optional[int]:
return center

def init_centers(self, dimension: int):
"""Initializes the global speaker centers
Args:
dimension (int): dimension of embeddings used for representing a speaker
"""
self.centers = np.zeros((self.max_speakers, dimension))
self.active_centers = set()
self.blocked_centers = set()

def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray):
"""Updates the centroids of the speaker clusters given a list of assignments and embeddigns of local speakers
Args:
assignments (Iterable[Tuple[int, int]]): An iterable of tuples for assigning each local speaker to one global speaker.
embeddings (np.ndarray): embeddings extracted for each of the local speakers that are within the assignment list.
"""
if self.centers is not None:
for l_spk, g_spk in assignments:
assert g_spk in self.active_centers, "Cannot update unknown centers"
self.centers[g_spk] += embeddings[l_spk]

def add_center(self, embedding: np.ndarray) -> int:
"""Define a new cluster centroid from a given embedding vector
Args:
embedding (np.ndarray): Embedding vector of some local speaker
Returns:
int: index of the created center
"""
center = self.get_next_center_position()
self.centers[center] = embedding
self.active_centers.add(center)
Expand All @@ -344,6 +376,16 @@ def identify(
segmentation: SlidingWindowFeature,
embeddings: torch.Tensor
) -> SpeakerMap:
""" Identify the clusters to which the input speaker embeddings belong.
Args:
segmentation (numpy.ndarray): Matrix of segmentation outputs of shape (n_frames, n_local_speakers)
embeddings (torch.Tensor): Matrix of embeddings of shape (n_local_speakers, dim_embedding)
Returns:
SpeakerMap: a SpeakerMap object wrapping a mapping matrix between local speakers and the
global speaker centroids maintained by the OnlineSpeakerClustering object.
"""
embeddings = embeddings.detach().cpu().numpy()
active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0]
long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0]
Expand Down
20 changes: 20 additions & 0 deletions src/diart/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]:
def hard_speaker_map(
self, num_src: int, num_tgt: int, assignments: Iterable[Tuple[int, int]]
) -> SpeakerMap:
"""Returns a SpeakerMap object based on the given assignments.
Args:
num_src (int): Number of source speakers (also called local speakers)
num_tgt (int): Number of target speakers (also called global speakers)
assignments (Iterable[Tuple[int, int]]): An iterable of tuples for assigning each local speaker to one global speaker.
Returns:
SpeakerMap: A SpeakerMap object wrapping a mapping matrix between local speakers and global speakers.
"""
mapping_matrix = self.invalid_tensor(shape=(num_src, num_tgt))
for src, tgt in assignments:
mapping_matrix[src, tgt] = self.best_possible_value
Expand Down Expand Up @@ -82,6 +92,16 @@ class SpeakerMapBuilder:
def hard_map(
shape: Tuple[int, int], assignments: Iterable[Tuple[int, int]], maximize: bool
) -> SpeakerMap:
"""Returns a SpeakerMap object based on the given assignments.
Args:
shape (Tuple[int, int]): shape of the mapping matrix
assignments (Iterable[Tuple[int, int]]): An iterable of tuples for assigning each local speaker to one global speaker.
maximize (bool): whether or not to use a MaximizationObjective
Returns:
SpeakerMap: A SpeakerMap object.
"""
num_src, num_tgt = shape
objective = MaximizationObjective if maximize else MinimizationObjective
return objective().hard_speaker_map(num_src, num_tgt, assignments)
Expand Down

0 comments on commit 7cf7234

Please sign in to comment.