Skip to content

Commit

Permalink
Make ThresholdType an enum
Browse files Browse the repository at this point in the history
  • Loading branch information
wq2012 committed Aug 18, 2021
1 parent 2f7b524 commit d7054de
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 35 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ build/*
dist/*
spectralcluster.egg-info/*
.coverage
.DS_Store
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,14 @@ You can specify your refinment operations like this:

```
from spectralcluster import RefinementOptions
from spectralcluster import ThresholdType
from spectralcluster import ICASSP2018_REFINEMENT_SEQUENCE
refinement_options = RefinementOptions(
gaussian_blur_sigma=1,
p_percentile=0.95,
thresholding_soft_multiplier=0.01,
thresholding_with_row_max=True,
thresholding_type=ThresholdType.RowMax,
refinement_sequence=ICASSP2018_REFINEMENT_SEQUENCE)
```

Expand All @@ -116,8 +117,8 @@ In the new version of this library, we support different types of Laplacian matr

* None Laplacian (affinity matrix): `W`
* Unnormalized Laplacian: `L = D - W`
* Graph cut Laplacian: `L' = D^{-1/2} L D^{-1/2}`
* Random walk Laplacian: `L' = D^{-1} L`
* Graph cut Laplacian: `L' = D^{-1/2} * L * D^{-1/2}`
* Random walk Laplacian: `L' = D^{-1} * L`

You can specify the Laplacian matrix type with the `laplacian_type` argument of the `SpectralClusterer` class.

Expand Down
4 changes: 3 additions & 1 deletion docs/configs.html
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ <h1 class="title">Module <code>spectralcluster.configs</code></h1>

RefinementName = refinement.RefinementName
RefinementOptions = refinement.RefinementOptions
ThresholdType = refinement.ThresholdType
SymmetrizeType = refinement.SymmetrizeType
SpectralClusterer = spectral_clusterer.SpectralClusterer


Expand All @@ -52,7 +54,7 @@ <h1 class="title">Module <code>spectralcluster.configs</code></h1>
gaussian_blur_sigma=1,
p_percentile=0.95,
thresholding_soft_multiplier=0.01,
thresholding_with_row_max=True,
thresholding_type=ThresholdType.RowMax,
refinement_sequence=ICASSP2018_REFINEMENT_SEQUENCE)

icassp2018_clusterer = SpectralClusterer(
Expand Down
1 change: 1 addition & 0 deletions docs/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ <h1 class="title">Package <code>spectralcluster</code></h1>

RefinementName = refinement.RefinementName
RefinementOptions = refinement.RefinementOptions
ThresholdType = refinement.ThresholdType
SymmetrizeType = refinement.SymmetrizeType

SpectralClusterer = spectral_clusterer.SpectralClusterer
Expand Down
115 changes: 85 additions & 30 deletions docs/refinement.html
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
RowWiseNormalize = 6


class ThresholdType(enum.Enum):
&#34;&#34;&#34;Different types of thresholding.&#34;&#34;&#34;
# We clear values that are smaller than row_max*p_percentile
RowMax = 1

# We clear (p_percentile*100)% smallest values of the entire row
Percentile = 2


class SymmetrizeType(enum.Enum):
&#34;&#34;&#34;Different types of symmetrization operation.&#34;&#34;&#34;
# We use max(A, A^T)
Expand All @@ -61,7 +70,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
gaussian_blur_sigma=1,
p_percentile=0.95,
thresholding_soft_multiplier=0.01,
thresholding_with_row_max=True,
thresholding_type=ThresholdType.RowMax,
thresholding_with_binarization=False,
thresholding_preserve_diagonal=False,
symmetrize_type=SymmetrizeType.Max,
Expand All @@ -73,8 +82,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
p_percentile: the p-percentile for the row wise thresholding
thresholding_soft_multiplier: the multiplier for soft threhsold, if this
value is 0, then it&#39;s a hard thresholding
thresholding_with_row_max: if true, we use row_max * p_percentile as row
wise threshold, instead of doing a percentile-based thresholding
thresholding_type: the type of thresholding operation
thresholding_with_binarization: if true, we set values larger than the
threshold to 1
thresholding_preserve_diagonal: if true, in the row wise thresholding
Expand All @@ -88,7 +96,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
self.gaussian_blur_sigma = gaussian_blur_sigma
self.p_percentile = p_percentile
self.thresholding_soft_multiplier = thresholding_soft_multiplier
self.thresholding_with_row_max = thresholding_with_row_max
self.thresholding_type = thresholding_type
self.thresholding_with_binarization = thresholding_with_binarization
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
self.symmetrize_type = symmetrize_type
Expand Down Expand Up @@ -121,7 +129,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
elif name == RefinementName.RowWiseThreshold:
return RowWiseThreshold(self.p_percentile,
self.thresholding_soft_multiplier,
self.thresholding_with_row_max,
self.thresholding_type,
self.thresholding_with_binarization,
self.thresholding_preserve_diagonal)
elif name == RefinementName.Symmetrize:
Expand Down Expand Up @@ -203,12 +211,14 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
def __init__(self,
p_percentile=0.95,
thresholding_soft_multiplier=0.01,
thresholding_with_row_max=False,
thresholding_type=ThresholdType.RowMax,
thresholding_with_binarization=False,
thresholding_preserve_diagonal=False):
self.p_percentile = p_percentile
self.multiplier = thresholding_soft_multiplier
self.thresholding_with_row_max = thresholding_with_row_max
if not isinstance(thresholding_type, ThresholdType):
raise TypeError(&#34;thresholding_type must be a ThresholdType&#34;)
self.thresholding_type = thresholding_type
self.thresholding_with_binarization = thresholding_with_binarization
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal

Expand All @@ -217,17 +227,19 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
refined_affinity = np.copy(affinity)
if self.thresholding_preserve_diagonal:
np.fill_diagonal(refined_affinity, 0.0)
if self.thresholding_with_row_max:
if self.thresholding_type == ThresholdType.RowMax:
# Row_max based thresholding
row_max = refined_affinity.max(axis=1)
row_max = np.expand_dims(row_max, axis=1)
is_smaller = refined_affinity &lt; (row_max * self.p_percentile)
else:
elif self.thresholding_type == ThresholdType.Percentile:
# Percentile based thresholding
row_percentile = np.percentile(
refined_affinity, self.p_percentile * 100, axis=1)
row_percentile = np.expand_dims(row_percentile, axis=1)
is_smaller = refined_affinity &lt; row_percentile
else:
raise ValueError(&#34;Unsupported thresholding_type&#34;)
if self.thresholding_with_binarization:
# For values larger than the threshold, we binarize them to 1
refined_affinity = (np.ones_like(
Expand All @@ -245,13 +257,13 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
&#34;&#34;&#34;The Symmetrization operation.&#34;&#34;&#34;

def __init__(self, symmetrize_type=SymmetrizeType.Max):
if not isinstance(symmetrize_type, SymmetrizeType):
raise TypeError(&#34;symmetrize_type must be a SymmetrizeType&#34;)
self.symmetrize_type = symmetrize_type

def refine(self, affinity):
self.check_input(affinity)
if not isinstance(self.symmetrize_type, SymmetrizeType):
raise TypeError(&#34;symmetrize_type must be a SymmetrizeType&#34;)
elif self.symmetrize_type == SymmetrizeType.Max:
if self.symmetrize_type == SymmetrizeType.Max:
return np.maximum(affinity, np.transpose(affinity))
elif self.symmetrize_type == SymmetrizeType.Average:
return 0.5 * (affinity + np.transpose(affinity))
Expand Down Expand Up @@ -572,7 +584,7 @@ <h3>Class variables</h3>
</dd>
<dt id="spectralcluster.refinement.RefinementOptions"><code class="flex name class">
<span>class <span class="ident">RefinementOptions</span></span>
<span>(</span><span>gaussian_blur_sigma=1, p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_with_row_max=True, thresholding_with_binarization=False, thresholding_preserve_diagonal=False, symmetrize_type=SymmetrizeType.Max, refinement_sequence=None)</span>
<span>(</span><span>gaussian_blur_sigma=1, p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_type=ThresholdType.RowMax, thresholding_with_binarization=False, thresholding_preserve_diagonal=False, symmetrize_type=SymmetrizeType.Max, refinement_sequence=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Refinement options for the affinity matrix.</p>
Expand All @@ -586,9 +598,8 @@ <h2 id="args">Args</h2>
<dt><strong><code>thresholding_soft_multiplier</code></strong></dt>
<dd>the multiplier for soft threhsold, if this
value is 0, then it's a hard thresholding</dd>
<dt><strong><code>thresholding_with_row_max</code></strong></dt>
<dd>if true, we use row_max * p_percentile as row
wise threshold, instead of doing a percentile-based thresholding</dd>
<dt><strong><code>thresholding_type</code></strong></dt>
<dd>the type of thresholding operation</dd>
<dt><strong><code>thresholding_with_binarization</code></strong></dt>
<dd>if true, we set values larger than the
threshold to 1</dd>
Expand All @@ -614,7 +625,7 @@ <h2 id="args">Args</h2>
gaussian_blur_sigma=1,
p_percentile=0.95,
thresholding_soft_multiplier=0.01,
thresholding_with_row_max=True,
thresholding_type=ThresholdType.RowMax,
thresholding_with_binarization=False,
thresholding_preserve_diagonal=False,
symmetrize_type=SymmetrizeType.Max,
Expand All @@ -626,8 +637,7 @@ <h2 id="args">Args</h2>
p_percentile: the p-percentile for the row wise thresholding
thresholding_soft_multiplier: the multiplier for soft threhsold, if this
value is 0, then it&#39;s a hard thresholding
thresholding_with_row_max: if true, we use row_max * p_percentile as row
wise threshold, instead of doing a percentile-based thresholding
thresholding_type: the type of thresholding operation
thresholding_with_binarization: if true, we set values larger than the
threshold to 1
thresholding_preserve_diagonal: if true, in the row wise thresholding
Expand All @@ -641,7 +651,7 @@ <h2 id="args">Args</h2>
self.gaussian_blur_sigma = gaussian_blur_sigma
self.p_percentile = p_percentile
self.thresholding_soft_multiplier = thresholding_soft_multiplier
self.thresholding_with_row_max = thresholding_with_row_max
self.thresholding_type = thresholding_type
self.thresholding_with_binarization = thresholding_with_binarization
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
self.symmetrize_type = symmetrize_type
Expand Down Expand Up @@ -674,7 +684,7 @@ <h2 id="args">Args</h2>
elif name == RefinementName.RowWiseThreshold:
return RowWiseThreshold(self.p_percentile,
self.thresholding_soft_multiplier,
self.thresholding_with_row_max,
self.thresholding_type,
self.thresholding_with_binarization,
self.thresholding_preserve_diagonal)
elif name == RefinementName.Symmetrize:
Expand Down Expand Up @@ -733,7 +743,7 @@ <h2 id="raises">Raises</h2>
elif name == RefinementName.RowWiseThreshold:
return RowWiseThreshold(self.p_percentile,
self.thresholding_soft_multiplier,
self.thresholding_with_row_max,
self.thresholding_type,
self.thresholding_with_binarization,
self.thresholding_preserve_diagonal)
elif name == RefinementName.Symmetrize:
Expand Down Expand Up @@ -783,7 +793,7 @@ <h3>Inherited members</h3>
</dd>
<dt id="spectralcluster.refinement.RowWiseThreshold"><code class="flex name class">
<span>class <span class="ident">RowWiseThreshold</span></span>
<span>(</span><span>p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_with_row_max=False, thresholding_with_binarization=False, thresholding_preserve_diagonal=False)</span>
<span>(</span><span>p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_type=ThresholdType.RowMax, thresholding_with_binarization=False, thresholding_preserve_diagonal=False)</span>
</code></dt>
<dd>
<div class="desc"><p>Apply row wise thresholding.</p></div>
Expand All @@ -797,12 +807,14 @@ <h3>Inherited members</h3>
def __init__(self,
p_percentile=0.95,
thresholding_soft_multiplier=0.01,
thresholding_with_row_max=False,
thresholding_type=ThresholdType.RowMax,
thresholding_with_binarization=False,
thresholding_preserve_diagonal=False):
self.p_percentile = p_percentile
self.multiplier = thresholding_soft_multiplier
self.thresholding_with_row_max = thresholding_with_row_max
if not isinstance(thresholding_type, ThresholdType):
raise TypeError(&#34;thresholding_type must be a ThresholdType&#34;)
self.thresholding_type = thresholding_type
self.thresholding_with_binarization = thresholding_with_binarization
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal

Expand All @@ -811,17 +823,19 @@ <h3>Inherited members</h3>
refined_affinity = np.copy(affinity)
if self.thresholding_preserve_diagonal:
np.fill_diagonal(refined_affinity, 0.0)
if self.thresholding_with_row_max:
if self.thresholding_type == ThresholdType.RowMax:
# Row_max based thresholding
row_max = refined_affinity.max(axis=1)
row_max = np.expand_dims(row_max, axis=1)
is_smaller = refined_affinity &lt; (row_max * self.p_percentile)
else:
elif self.thresholding_type == ThresholdType.Percentile:
# Percentile based thresholding
row_percentile = np.percentile(
refined_affinity, self.p_percentile * 100, axis=1)
row_percentile = np.expand_dims(row_percentile, axis=1)
is_smaller = refined_affinity &lt; row_percentile
else:
raise ValueError(&#34;Unsupported thresholding_type&#34;)
if self.thresholding_with_binarization:
# For values larger than the threshold, we binarize them to 1
refined_affinity = (np.ones_like(
Expand Down Expand Up @@ -862,13 +876,13 @@ <h3>Inherited members</h3>
&#34;&#34;&#34;The Symmetrization operation.&#34;&#34;&#34;

def __init__(self, symmetrize_type=SymmetrizeType.Max):
if not isinstance(symmetrize_type, SymmetrizeType):
raise TypeError(&#34;symmetrize_type must be a SymmetrizeType&#34;)
self.symmetrize_type = symmetrize_type

def refine(self, affinity):
self.check_input(affinity)
if not isinstance(self.symmetrize_type, SymmetrizeType):
raise TypeError(&#34;symmetrize_type must be a SymmetrizeType&#34;)
elif self.symmetrize_type == SymmetrizeType.Max:
if self.symmetrize_type == SymmetrizeType.Max:
return np.maximum(affinity, np.transpose(affinity))
elif self.symmetrize_type == SymmetrizeType.Average:
return 0.5 * (affinity + np.transpose(affinity))
Expand Down Expand Up @@ -923,6 +937,40 @@ <h3>Class variables</h3>
</dd>
</dl>
</dd>
<dt id="spectralcluster.refinement.ThresholdType"><code class="flex name class">
<span>class <span class="ident">ThresholdType</span></span>
<span>(</span><span>value, names=None, *, module=None, qualname=None, type=None, start=1)</span>
</code></dt>
<dd>
<div class="desc"><p>Different types of thresholding.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class ThresholdType(enum.Enum):
&#34;&#34;&#34;Different types of thresholding.&#34;&#34;&#34;
# We clear values that are smaller than row_max*p_percentile
RowMax = 1

# We clear (p_percentile*100)% smallest values of the entire row
Percentile = 2</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>enum.Enum</li>
</ul>
<h3>Class variables</h3>
<dl>
<dt id="spectralcluster.refinement.ThresholdType.Percentile"><code class="name">var <span class="ident">Percentile</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="spectralcluster.refinement.ThresholdType.RowMax"><code class="name">var <span class="ident">RowMax</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
</dd>
</dl>
</section>
</article>
Expand Down Expand Up @@ -988,6 +1036,13 @@ <h4><code><a title="spectralcluster.refinement.SymmetrizeType" href="#spectralcl
<li><code><a title="spectralcluster.refinement.SymmetrizeType.Max" href="#spectralcluster.refinement.SymmetrizeType.Max">Max</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="spectralcluster.refinement.ThresholdType" href="#spectralcluster.refinement.ThresholdType">ThresholdType</a></code></h4>
<ul class="">
<li><code><a title="spectralcluster.refinement.ThresholdType.Percentile" href="#spectralcluster.refinement.ThresholdType.Percentile">Percentile</a></code></li>
<li><code><a title="spectralcluster.refinement.ThresholdType.RowMax" href="#spectralcluster.refinement.ThresholdType.RowMax">RowMax</a></code></li>
</ul>
</li>
</ul>
</li>
</ul>
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import setuptools

VERSION = "0.2.0"
VERSION = "0.2.1"

with open("README.md", "r") as file_object:
LONG_DESCRIPTION = file_object.read()
Expand Down

0 comments on commit d7054de

Please sign in to comment.