Skip to content

Commit

Permalink
Merge pull request #622 from nextcloud/hdbscan
Browse files Browse the repository at this point in the history
  • Loading branch information
marcelklehr authored Feb 6, 2023
2 parents d8719ff + 75586ac commit 2e52fcd
Show file tree
Hide file tree
Showing 11 changed files with 2,554 additions and 135 deletions.
137 changes: 137 additions & 0 deletions lib/Clustering/DualTreeBall.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
<?php
/*
* Copyright (c) 2023 The Recognize contributors.
* This file is licensed under the Affero General Public License version 3 or later. See the COPYING file.
*/

namespace OCA\Recognize\Clustering;

use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Graph\Nodes\Ball;
use Rubix\ML\Helpers\Stats;
use Rubix\ML\Kernels\Distance\Distance;
use function Rubix\ML\argmax;

/**
 * Ball node carrying extra bookkeeping: a per-node distance bound and a
 * "fully connected" flag with the id of the set the node belongs to.
 */
class DualTreeBall extends Ball {
    /**
     * Smallest value handed to setLongestDistance() since the last reset;
     * INF while nothing has been recorded.
     */
    protected float $longestDistanceInNode = INF;

    /**
     * True once every child has reported the same non-null set id.
     */
    protected bool $fullyConnected = false;

    /**
     * Set id shared by all children; only meaningful while $fullyConnected is true.
     *
     * @var null|int|string
     */
    protected $setId;

    /**
     * Record a distance for this node. The stored value is only ever lowered:
     * values that are not smaller than the current one are ignored.
     *
     * @param float $longestDistance
     */
    public function setLongestDistance(float $longestDistance): void {
        if ($longestDistance < $this->longestDistanceInNode) {
            $this->longestDistanceInNode = $longestDistance;
        }
    }

    /**
     * @return float the currently stored distance (INF if none was recorded)
     */
    public function getLongestDistance(): float {
        return $this->longestDistanceInNode;
    }

    /**
     * Reset the stored distance to INF on this node and, recursively, on all children.
     */
    public function resetLongestEdge(): void {
        $this->longestDistanceInNode = INF;

        foreach ($this->children() as $child) {
            $child->resetLongestEdge();
        }
    }

    /**
     * Clear the fully-connected flag on this node and, recursively, on all children.
     */
    public function resetFullyConnectedStatus(): void {
        $this->fullyConnected = false;

        foreach ($this->children() as $child) {
            $child->resetFullyConnectedStatus();
        }
    }

    /**
     * @return null|int|string the set id, or null while the node is not fully connected
     */
    public function getSetId() {
        return $this->fullyConnected ? $this->setId : null;
    }

    public function isFullyConnected(): bool {
        return $this->fullyConnected;
    }

    /**
     * Refresh the fully-connected state of this subtree from a label => set id map.
     *
     * Every child is always visited, so that sibling subtrees get their own
     * state refreshed even when this node turns out not to be fully connected.
     *
     * @param array $labelToSetId map from sample label to the id of the set it belongs to
     * @return null|int|string the common set id of this subtree, or null if it is not fully connected
     */
    public function propagateSetChanges(array &$labelToSetId) {
        if ($this->fullyConnected) {
            // Already fully connected: only the set id can have changed, so
            // pull the current one up from the children.
            foreach ($this->children() as $child) {
                $this->setId = $child->propagateSetChanges($labelToSetId);
            }

            return $this->setId;
        }

        // If, and only if, all children are fully connected and share one set id
        // then we, too, are fully connected.
        $commonSetId = null;
        $allChildrenAgree = true;

        foreach ($this->children() as $child) {
            // Recurse unconditionally: bailing out on the first failing child
            // would leave the remaining children with stale state.
            $childSetId = $child->propagateSetChanges($labelToSetId);

            if ($childSetId === null || ($commonSetId !== null && $commonSetId !== $childSetId)) {
                $allChildrenAgree = false;
                continue;
            }

            $commonSetId = $childSetId;
        }

        // A null id here means there were no children to report one; such a
        // node must not be flagged as connected.
        if (!$allChildrenAgree || $commonSetId === null) {
            return null;
        }

        $this->setId = $commonSetId;
        $this->fullyConnected = true;

        return $this->setId;
    }

    /**
     * Factory method to build a hypersphere by splitting the dataset into left and right clusters.
     *
     * The center is the per-feature mean for continuous columns and the most
     * frequent value otherwise; the radius is the largest distance from the
     * center to any sample. The left centroid is the sample farthest from the
     * center, the right centroid the sample farthest from the left centroid.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     * @param \Rubix\ML\Kernels\Distance\Distance $kernel
     * @return self
     */
    public static function split(Labeled $dataset, Distance $kernel): self {
        $center = [];

        foreach ($dataset->features() as $column => $values) {
            if ($dataset->featureType($column)->isContinuous()) {
                $center[] = Stats::mean($values);
            } else {
                $center[] = argmax(array_count_values($values));
            }
        }

        $distancesToCenter = [];

        foreach ($dataset->samples() as $sample) {
            $distancesToCenter[] = $kernel->compute($sample, $center);
        }

        $radius = max($distancesToCenter) ?: 0.0;

        $leftCentroid = $dataset->sample(argmax($distancesToCenter));

        $distancesToLeft = [];

        foreach ($dataset->samples() as $sample) {
            $distancesToLeft[] = $kernel->compute($sample, $leftCentroid);
        }

        $rightCentroid = $dataset->sample(argmax($distancesToLeft));

        $subsets = $dataset->spatialSplit($leftCentroid, $rightCentroid, $kernel);

        return new self($center, $radius, $subsets);
    }
}
106 changes: 106 additions & 0 deletions lib/Clustering/DualTreeClique.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
<?php
/*
* Copyright (c) 2023 The Recognize contributors.
* This file is licensed under the Affero General Public License version 3 or later. See the COPYING file.
*/

namespace OCA\Recognize\Clustering;

use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Graph\Nodes\Clique;
use Rubix\ML\Helpers\Stats;
use Rubix\ML\Kernels\Distance\Distance;
use function Rubix\ML\argmax;

/**
 * Clique (leaf) node carrying extra bookkeeping: a per-node distance value
 * and a "fully connected" flag with the id of the set its samples belong to.
 */
class DualTreeClique extends Clique {
    /**
     * Last value handed to setLongestDistance(); INF after a reset.
     */
    protected float $longestDistanceInNode = INF;

    /**
     * True once all labels of the dataset were found to map to one set id.
     */
    protected bool $fullyConnected = false;

    /**
     * @var null|int|string
     */
    protected $setId;

    /**
     * Store the given distance, unconditionally overwriting the previous value.
     *
     * @param float $longestDistance
     */
    public function setLongestDistance(float $longestDistance): void {
        $this->longestDistanceInNode = $longestDistance;
    }

    /**
     * @return float the currently stored distance (INF if none was recorded)
     */
    public function getLongestDistance(): float {
        return $this->longestDistanceInNode;
    }

    /**
     * Reset the stored distance to INF.
     */
    public function resetLongestEdge(): void {
        $this->longestDistanceInNode = INF;
    }

    /**
     * Clear the fully-connected flag.
     */
    public function resetFullyConnectedStatus(): void {
        $this->fullyConnected = false;
    }

    /**
     * @return int|string|null the set id, or null while the node is not fully connected
     */
    public function getSetId() {
        return $this->fullyConnected ? $this->setId : null;
    }

    public function isFullyConnected(): bool {
        return $this->fullyConnected;
    }

    /**
     * Refresh the fully-connected state of this node from a label => set id map.
     *
     * @param array<int|string,int|string> $labelToSetId map from sample label to the id of the set it belongs to
     * @return int|string|null the set id shared by all labels, or null if they do not all share one
     */
    public function propagateSetChanges(array &$labelToSetId) {
        if ($this->fullyConnected) {
            // All labels are known to share a set, so looking up the first one
            // is enough to pick up a changed id.
            $this->setId = $labelToSetId[$this->dataset->label(0)];

            return $this->setId;
        }

        $labels = $this->dataset->labels();

        // Without any labels there is no set id to agree on; do not use the
        // null from array_pop() as a lookup key and do not flag the node.
        if ($labels === []) {
            return null;
        }

        // Use one label as the reference and compare every other label against it.
        $candidateSetId = $labelToSetId[array_pop($labels)];

        foreach ($labels as $label) {
            if ($labelToSetId[$label] !== $candidateSetId) {
                return null;
            }
        }

        $this->fullyConnected = true;
        $this->setId = $candidateSetId;

        return $this->setId;
    }

    /**
     * Terminate a branch with a dataset.
     *
     * The center is the per-feature mean for continuous columns and the most
     * frequent value otherwise; the radius is the largest distance from the
     * center to any sample, or 0.0 when the dataset has no samples.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     * @param \Rubix\ML\Kernels\Distance\Distance $kernel
     * @return self
     */
    public static function terminate(Labeled $dataset, Distance $kernel): self {
        $center = [];

        foreach ($dataset->features() as $column => $values) {
            if ($dataset->featureType($column)->isContinuous()) {
                $center[] = Stats::mean($values);
            } else {
                $center[] = argmax(array_count_values($values));
            }
        }

        $distances = [];

        foreach ($dataset->samples() as $sample) {
            $distances[] = $kernel->compute($sample, $center);
        }

        // max() throws a ValueError on an empty array under PHP 8, so the
        // empty case has to be handled before calling it.
        $radius = $distances === [] ? 0.0 : (float) max($distances);

        return new self($dataset, $center, $radius);
    }
}
Loading

0 comments on commit 2e52fcd

Please sign in to comment.