-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #622 from nextcloud/hdbscan
- Loading branch information
Showing
11 changed files
with
2,554 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
<?php | ||
/* | ||
* Copyright (c) 2023 The Recognize contributors. | ||
* This file is licensed under the Affero General Public License version 3 or later. See the COPYING file. | ||
*/ | ||
|
||
namespace OCA\Recognize\Clustering; | ||
|
||
use Rubix\ML\Datasets\Labeled; | ||
use Rubix\ML\Graph\Nodes\Ball; | ||
use Rubix\ML\Helpers\Stats; | ||
use Rubix\ML\Kernels\Distance\Distance; | ||
use function Rubix\ML\argmax; | ||
|
||
class DualTreeBall extends Ball { | ||
protected float $longestDistanceInNode = INF; | ||
protected bool $fullyConnected = false; | ||
protected $setId; | ||
|
||
|
||
public function setLongestDistance($longestDistance): void { | ||
if ($longestDistance < $this->longestDistanceInNode) { | ||
$this->longestDistanceInNode = $longestDistance; | ||
} | ||
} | ||
|
||
public function getLongestDistance(): float { | ||
return $this->longestDistanceInNode; | ||
} | ||
|
||
public function resetLongestEdge(): void { | ||
$this->longestDistanceInNode = INF; | ||
foreach ($this->children() as $child) { | ||
$child->resetLongestEdge(); | ||
} | ||
} | ||
|
||
public function resetFullyConnectedStatus(): void { | ||
$this->fullyConnected = false; | ||
foreach ($this->children() as $child) { | ||
$child->resetFullyConnectedStatus(); | ||
} | ||
} | ||
|
||
/** | ||
* @return null|int|string | ||
*/ | ||
public function getSetId() { | ||
if (!$this->fullyConnected) { | ||
return null; | ||
} | ||
return $this->setId; | ||
} | ||
|
||
public function isFullyConnected(): bool { | ||
return $this->fullyConnected; | ||
} | ||
|
||
/** | ||
* @param array $labelToSetId | ||
* @return null|int|string | ||
*/ | ||
public function propagateSetChanges(array &$labelToSetId) { | ||
if ($this->fullyConnected) { | ||
// If we are already fully connected, we just need to check if the | ||
// set id has changed | ||
foreach ($this->children() as $child) { | ||
$this->setId = $child->propagateSetChanges($labelToSetId); | ||
} | ||
|
||
return $this->setId; | ||
} | ||
|
||
// If, and only if, both children are fully connected and in the same set id then | ||
// we, too, are fully connected | ||
$setId = null; | ||
foreach ($this->children() as $child) { | ||
$retVal = $child->propagateSetChanges($labelToSetId); | ||
|
||
if ($retVal === null) { | ||
return null; | ||
} | ||
|
||
if ($setId !== null && $setId !== $retVal) { | ||
return null; | ||
} | ||
|
||
$setId = $retVal; | ||
} | ||
|
||
$this->setId = $setId; | ||
$this->fullyConnected = true; | ||
|
||
return $this->setId; | ||
} | ||
|
||
/** | ||
* Factory method to build a hypersphere by splitting the dataset into left and right clusters. | ||
* | ||
* @param \Rubix\ML\Datasets\Labeled $dataset | ||
* @param \Rubix\ML\Kernels\Distance\Distance $kernel | ||
* @return self | ||
*/ | ||
public static function split(Labeled $dataset, Distance $kernel): self { | ||
$center = []; | ||
|
||
foreach ($dataset->features() as $column => $values) { | ||
if ($dataset->featureType($column)->isContinuous()) { | ||
$center[] = Stats::mean($values); | ||
} else { | ||
$center[] = argmax(array_count_values($values)); | ||
} | ||
} | ||
|
||
$distances = []; | ||
|
||
foreach ($dataset->samples() as $sample) { | ||
$distances[] = $kernel->compute($sample, $center); | ||
} | ||
|
||
$radius = max($distances) ?: 0.0; | ||
|
||
$leftCentroid = $dataset->sample(argmax($distances)); | ||
|
||
$distances = []; | ||
|
||
foreach ($dataset->samples() as $sample) { | ||
$distances[] = $kernel->compute($sample, $leftCentroid); | ||
} | ||
|
||
$rightCentroid = $dataset->sample(argmax($distances)); | ||
|
||
$subsets = $dataset->spatialSplit($leftCentroid, $rightCentroid, $kernel); | ||
|
||
return new self($center, $radius, $subsets); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
<?php | ||
/* | ||
* Copyright (c) 2023 The Recognize contributors. | ||
* This file is licensed under the Affero General Public License version 3 or later. See the COPYING file. | ||
*/ | ||
|
||
namespace OCA\Recognize\Clustering; | ||
|
||
use Rubix\ML\Datasets\Labeled; | ||
use Rubix\ML\Graph\Nodes\Clique; | ||
use Rubix\ML\Helpers\Stats; | ||
use Rubix\ML\Kernels\Distance\Distance; | ||
use function Rubix\ML\argmax; | ||
|
||
class DualTreeClique extends Clique { | ||
protected float $longestDistanceInNode = INF; | ||
protected bool $fullyConnected = false; | ||
/** | ||
* @var null|int|string | ||
*/ | ||
protected $setId; | ||
|
||
public function setLongestDistance(float $longestDistance): void { | ||
$this->longestDistanceInNode = $longestDistance; | ||
} | ||
|
||
public function getLongestDistance(): float { | ||
return $this->longestDistanceInNode; | ||
} | ||
|
||
public function resetLongestEdge(): void { | ||
$this->longestDistanceInNode = INF; | ||
} | ||
|
||
public function resetFullyConnectedStatus(): void { | ||
$this->fullyConnected = false; | ||
} | ||
|
||
/** | ||
* @return int|string|null | ||
*/ | ||
public function getSetId() { | ||
if (!$this->fullyConnected) { | ||
return null; | ||
} | ||
|
||
return $this->setId; | ||
} | ||
|
||
public function isFullyConnected(): bool { | ||
return $this->fullyConnected; | ||
} | ||
|
||
/** | ||
* @param array<int|string,int|string> $labelToSetId | ||
* @return int|mixed|string|null | ||
*/ | ||
public function propagateSetChanges(array &$labelToSetId) { | ||
if ($this->fullyConnected) { | ||
$this->setId = $labelToSetId[$this->dataset->label(0)]; | ||
return $this->setId; | ||
} | ||
|
||
$labels = $this->dataset->labels(); | ||
|
||
$setId = $labelToSetId[array_pop($labels)]; | ||
|
||
foreach ($labels as $label) { | ||
if ($setId !== $labelToSetId[$label]) { | ||
return null; | ||
} | ||
} | ||
|
||
$this->fullyConnected = true; | ||
$this->setId = $setId; | ||
|
||
return $this->setId; | ||
} | ||
|
||
/** | ||
* Terminate a branch with a dataset. | ||
* | ||
* @param \Rubix\ML\Datasets\Labeled $dataset | ||
* @param \Rubix\ML\Kernels\Distance\Distance $kernel | ||
* @return self | ||
*/ | ||
public static function terminate(Labeled $dataset, Distance $kernel): self { | ||
$center = []; | ||
|
||
foreach ($dataset->features() as $column => $values) { | ||
if ($dataset->featureType($column)->isContinuous()) { | ||
$center[] = Stats::mean($values); | ||
} else { | ||
$center[] = argmax(array_count_values($values)); | ||
} | ||
} | ||
|
||
$distances = []; | ||
|
||
foreach ($dataset->samples() as $sample) { | ||
$distances[] = $kernel->compute($sample, $center); | ||
} | ||
$radius = max($distances) ?: 0.0; | ||
return new self($dataset, $center, $radius); | ||
} | ||
} |
Oops, something went wrong.