From 00b2a1b555bb75e9931c4b3730fb7384c283d670 Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Sat, 7 Jan 2023 16:16:39 +0100 Subject: [PATCH 01/17] Implement hdbscan Signed-off-by: Marcel Klehr --- lib/Service/FaceClusterAnalyzer.php | 636 +++++++++++++++++++++++----- 1 file changed, 539 insertions(+), 97 deletions(-) diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index 29651769..dc8cb8f0 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -10,22 +10,428 @@ use OCA\Recognize\Db\FaceClusterMapper; use OCA\Recognize\Db\FaceDetection; use OCA\Recognize\Db\FaceDetectionMapper; -use Rubix\ML\Clusterers\DBSCAN; -use Rubix\ML\Datasets\Unlabeled; +use Rubix\ML\Datasets\Labeled; use Rubix\ML\Graph\Trees\BallTree; use Rubix\ML\Kernels\Distance\Euclidean; +use Rubix\ML\Kernels\Distance\Distance; +use Rubix\ML\Graph\Nodes\Hypersphere; +use Rubix\ML\Graph\Nodes\Clique; +use Rubix\ML\Graph\Nodes\Ball; +use Rubix\ML\Exceptions\InvalidArgumentException; +use SplObjectStorage; + +class MrdBallTree extends BallTree { + private array $coreDistances; + private int $coreDistSampleSize; + private Labeled $dataset; + + /** + * @param int $maxLeafSize + * @param int $coreDistSampleSize + * @param \Rubix\ML\Kernels\Distance\Distance|null $kernel + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + */ + public function __construct(int $maxLeafSize = 30, int $coreDistSampleSize = 30, ? Distance $kernel = null) { + if ($maxLeafSize < 1) { + throw new InvalidArgumentException('At least one sample is required' + . " to form a leaf node, $maxLeafSize given."); + } + + if ($coreDistSampleSize < 2) { + throw new InvalidArgumentException('At least two samples are required' + . " to calculate core distance, $coreDistSampleSize given."); + } + + $this->maxLeafSize = $maxLeafSize; + $this->coreDistSampleSize = $coreDistSampleSize; + $this->kernel = $kernel ?? new Euclidean(); + } + + /** + * Run a k nearest neighbors search excluding the provided group of samples and return the labels, and distances for the nn in a tuple. + * + * This is essentially a stripped down version of nearest() in the parent class provided by Rubix. + * + * TODO: Implement last minimum distance filter for the hyperspheres to accelerate the search. + * Hyperspheres with centroid distance+radius smaller than the previous nearest neighbor cannot contain + * the next nearest neighbor. + * + * @internal + * + * @param int $sampleKey + * @param list $sample + * @param list $groupLabels + * @param int $k + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + * @return array{list,list} + */ + public function nearestNotInGroup(int $sampleKey, array $sample, array $groupLabels, int $k = 1): array { + $visited = new SplObjectStorage(); + + $stack = $this->path($sample); + + /*$samples =*/$labels = $distances = []; + + while ($current = array_pop($stack)) { + if ($current instanceof Ball) { + $radius = $distances[$k - 1] ?? INF; + + foreach ($current->children() as $child) { + if (!$visited->contains($child)) { + if ($child instanceof Hypersphere) { + $distance = $this->kernel->compute($sample, $child->center()); + + if ($distance - $child->radius() < $radius) { + $stack[] = $child; + + continue; + } + } + + $visited->attach($child); + } + } + + $visited->attach($current); + + continue; + } + + if ($current instanceof Clique) { + $dataset = $current->dataset(); + $neighborLabels = $dataset->labels(); + + foreach ($dataset->samples() as $key => $neighbor) { + if (!in_array($neighborLabels[$key], $groupLabels)) { + $distances[] = $this->computeMrd($sampleKey, $sample, $neighborLabels[$key], $neighbor); + $labels[] = $neighborLabels[$key]; + } + } + + //$samples = array_merge($samples, $dataset->samples()); + //$labels = array_merge($labels, $dataset->labels()); + + array_multisort($distances, /*$samples,*/$labels); + + if (count($labels) > $k) { + //$samples = array_slice($samples, 0, $k); + $labels = array_slice($labels, 0, $k); + $distances = array_slice($distances, 0, $k); + } + + $visited->attach($current); + } + } + + return [ /*$samples,*/$labels, $distances]; + } + + private function getCoreDistance(int $index): float { + if (!isset($this->coreDistances[$index])) { + [$_1, $_2, $distances] = $this->nearest($this->dataset->sample($index), $this->coreDistSampleSize); + $this->coreDistances[$index] = max($distances); + } + + return $this->coreDistances[$index]; + } + + /** + * Compute the mutual reachability distance between two vectors. + * + * @internal + * + * @param int $a + * @param array $a_vector + * @param int $b + * @param array $b_vector + * @return float + */ + private function computeMrd(int $a, array $a_vector, int $b, array $b_vector): float { + $distance = $this->kernel->compute($a_vector, $b_vector); + + return max($distance, $this->getCoreDistance($a), $this->getCoreDistance($b)); + } + + /** + * Insert a root node and recursively split the dataset until a terminating + * condition is met. + * + * @internal + * + * @param \Rubix\ML\Datasets\Labeled $dataset + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + */ + public function grow(Labeled $dataset): void { + $this->dataset = $dataset; + $this->root = Ball::split($dataset, $this->kernel); + + $stack = [$this->root]; + + while ($current = array_pop($stack)) { + [$left, $right] = $current->subsets(); + + $current->cleanup(); + + if ($left->numSamples() > $this->maxLeafSize) { + $node = Ball::split($left, $this->kernel); + + $current->attachLeft($node); + + $stack[] = $node; + } elseif (!$left->empty()) { + $current->attachLeft(Clique::terminate($left, $this->kernel)); + } + + if ($right->numSamples() > $this->maxLeafSize) { + $node = Ball::split($right, $this->kernel); + + if ($node->isPoint()) { + $current->attachRight(Clique::terminate($right, $this->kernel)); + } else { + $current->attachRight($node); + + $stack[] = $node; + } + } elseif (!$right->empty()) { + $current->attachRight(Clique::terminate($right, $this->kernel)); + } + } + } +} + +class MstFaceCluster { + private array $edges; + private array $remainingEdges; + private float $startingLambda; + private float $finalLambda; + private float $clusterWeight; + private int $minimumClusterSize; + private array $coreEdges; + private bool $isRoot; + private float $maxEdgeLength; + private float $minClusterSeparation; + + + + + public function __construct(array $edges, int $minimumClusterSize, ?float $startingLambda = null, float $maxEdgeLength = 0.5, float $minClusterSeparation = 0.1) { + //Ascending sort of edges while perserving original keys. + $this->edges = $edges; + + uasort($this->edges, function ($a, $b) { + if ($a[1] > $b[1]) { + return 1; + } + if ($a[1] < $b[1]) { + return -1; + } + return 0; + }); + + $this->remainingEdges = $this->edges; + + if (is_null($startingLambda)) { + $this->isRoot = true; + $this->startingLambda = 0.0; + } else { + $this->isRoot = false; + $this->startingLambda = $startingLambda; + } + + $this->minimumClusterSize = $minimumClusterSize; + + $this->coreEdges = []; + + $this->clusterWeight = 0.0; + + $this->maxEdgeLength = $maxEdgeLength; + $this->minClusterSeparation = $minClusterSeparation; + } + + public function processCluster(): array { + $currentLambda = $lastLambda = $this->startingLambda; + $edgeLength = INF; + + while (true) { + $edgeCount = count($this->remainingEdges); + + if ($edgeCount < ($this->minimumClusterSize - 1)) { + if ($edgeLength > $this->maxEdgeLength) { + return []; + } + + $this->finalLambda = $currentLambda; + $this->coreEdges = $this->remainingEdges; + + return [$this]; + } + + $vertexConnectedTo = array_key_last($this->remainingEdges); + $currentLongestEdge = array_pop($this->remainingEdges); + $vertexConnectedFrom = $currentLongestEdge[0]; + + $edgeLength = $currentLongestEdge[1]; + + if ($edgeLength > $this->maxEdgeLength) { + // Prevent formation of clusters with edges longer than the maximum edge length + $currentLambda = $lastLambda = 1 / $edgeLength; + } elseif ($edgeLength > 0.0) { + $currentLambda = 1 / $edgeLength; + } + + $this->clusterWeight += ($currentLambda - $lastLambda) * $edgeCount; + $lastLambda = $currentLambda; + + if (!$this->pruneFromCluster($vertexConnectedTo) && !$this->pruneFromCluster($vertexConnectedFrom)) { + // This cluster will (probably) split into two child clusters: + + $childClusterEdges1 = $this->getChildClusterEdges($vertexConnectedTo); + $childClusterEdges2 = $this->getChildClusterEdges($vertexConnectedFrom); + + if ($edgeLength < $this->minClusterSeparation) { + $this->remainingEdges = count($childClusterEdges1) > count($childClusterEdges2) ? $childClusterEdges1 : $childClusterEdges2; + continue; + } + + // Choose clusters using excess of mass method: + // Return a list of children if the weight of all children is more than $this->clusterWeight. + // Otherwise return the current cluster and discard the children. This way we "choose" a combination + // of cluster that has weighs the most (i.e. has most excess of mass). Always discard the root cluster. + $this->finalLambda = $currentLambda; + + $childCluster1 = new MstFaceCluster($childClusterEdges1, $this->minimumClusterSize, $this->finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); + $childCluster2 = new MstFaceCluster($childClusterEdges2, $this->minimumClusterSize, $this->finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); + + // Resolve all chosen child clusters recursively + $childClusters = array_merge($childCluster1->processCluster(), $childCluster2->processCluster()); + + $childrenWeight = 0.0; + foreach ($childClusters as $childCluster) { + $childrenWeight += $childCluster->getClusterWeight(); + array_merge($this->coreEdges, $childCluster->getCoreEdges()); + } + + if (($childrenWeight > $this->clusterWeight) || $this->isRoot) { + return $childClusters; + } + + return [$this]; + } + + if ($edgeLength > $this->maxEdgeLength) { + $this->edges = $this->remainingEdges; + } + } + } + + private function pruneFromCluster(int $vertexId): bool { + $edgeIndicesToPrune = []; + $vertexStack = [$vertexId]; + + while (!empty($vertexStack)) { + $currentVertex = array_pop($vertexStack); + + if (count($edgeIndicesToPrune) >= ($this->minimumClusterSize - 1)) { + return false; + } + + // Traverse the MST edges backward + if (isset($this->remainingEdges[$currentVertex]) && !in_array($currentVertex, $edgeIndicesToPrune)) { + $incomingEdge = $this->remainingEdges[$currentVertex]; + $edgeIndicesToPrune[] = $currentVertex; + + $vertexStack[] = $incomingEdge[0]; + } + + // Traverse the MST edges forward + foreach ($this->remainingEdges as $key => $edge) { + if (($edge[0] == $currentVertex) && !in_array($key, $edgeIndicesToPrune)) { + $vertexStack[] = $key; + $edgeIndicesToPrune[] = $key; + } + } + } + + // Prune edges + foreach ($edgeIndicesToPrune as $edgeToPrune) { + unset($this->remainingEdges[$edgeToPrune]); + } + + return true; + } + + private function getChildClusterEdges(int $vertexId): array { + $vertexStack = [$vertexId]; + $edgesInCluster = []; + + while (!empty($vertexStack)) { + $currentVertex = array_pop($vertexStack); + + // Traverse the MST edges backward + if (isset($this->remainingEdges[$currentVertex]) && !isset($edgesInCluster[$currentVertex])) { + $incomingEdge = $this->remainingEdges[$currentVertex]; + + //Edges are indexed by the vertex they're connected to + $edgesInCluster[$currentVertex] = $incomingEdge; + + $vertexStack[] = $incomingEdge[0]; + } + + // Traverse the MST edges forward + foreach ($this->remainingEdges as $key => $edge) { + if ($edge[0] == $currentVertex && !isset($edgesInCluster[$key])) { + $vertexStack[] = $key; + $edgesInCluster[$key] = $edge; + } + } + } + + return $edgesInCluster; + } + + public function getClusterWeight(): float { + return $this->clusterWeight; + } + + public function getVertexKeys(): array { + $vertexKeys = []; + + foreach ($this->edges as $key => $edge) { + $vertexKeys[] = $key; + $vertexKeys[] = $edge[0]; + } + + return array_unique($vertexKeys); + } + + public function getCoreEdges(): array { + return $this->coreEdges; + } +} + class FaceClusterAnalyzer { - public const MIN_CLUSTER_DENSITY = 6; - public const MAX_INNER_CLUSTER_RADIUS = 0.44; + public const MIN_SAMPLE_SIZE = 4; // Conservative value: 10 + public const MIN_CLUSTER_SIZE = 6; // Conservative value: 10 + public const MAX_CLUSTER_EDGE_LENGHT = 99.0; + public const MIN_CLUSTER_SEPARATION = 0.0; public const DIMENSIONS = 128; - public const MIN_DETECTION_SIZE = 0.09; + public const ROOT_VERTEX = null; private FaceDetectionMapper $faceDetections; private FaceClusterMapper $faceClusters; private TagManager $tagManager; private Logger $logger; + private array $connectedVertices; + private array $distanceHeap; + private array $staleEdgeHeap; + private array $edges; + private float $shortestStaleDistance; + private float $shortestDistance; + private Labeled $dataset; + private MrdBallTree $detectionsTree; + public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapper $faceClusters, TagManager $tagManager, Logger $logger) { $this->faceDetections = $faceDetections; $this->faceClusters = $faceClusters; @@ -33,77 +439,142 @@ public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapp $this->logger = $logger; } + private function updateNnForVertex($vertexId): void { + if (is_null($vertexId)) { + return; + } + + [$nearestLabel, $nearestDistance] = $this->detectionsTree->nearestNotInGroup($vertexId, $this->dataset->sample($vertexId), $this->connectedVertices); + + // Two possibilities here: First, it's possibe that the distance to the nearest neighbor is less than + // the previous distance established for this vertex. Then we can remove this + // stale key from the $staleEdgeHeap. The second possibility is that the new distance is not the shortest + // available to the nearest neighbor in which case we just push the current $staleDetectionKey + // back into stale heap with the new longer distance. + if (($this->distanceHeap[$nearestLabel[0]] ?? INF) > $nearestDistance[0]) { + $this->distanceHeap[$nearestLabel[0]] = $nearestDistance[0]; + // If the nearest neighbor vertex already had an edge connected to it + // (with a longer distance) set the existing edge as stale. + if (isset($this->edges[$nearestLabel[0]])) { + $this->staleEdgeHeap[] = $this->edges[$nearestLabel[0]]; + } + + $this->edges[$nearestLabel[0]] = [$vertexId, $nearestDistance[0]]; + arsort($this->distanceHeap); + $this->shortestDistance = end($this->distanceHeap); + } else { + $this->staleEdgeHeap[] = [$vertexId, $nearestDistance[0]]; + } + + $distanceColumn = array_column($this->staleEdgeHeap, 1); + array_multisort($distanceColumn, SORT_DESC, $this->staleEdgeHeap); + $this->shortestStaleDistance = end($this->staleEdgeHeap)[1]; + + return; + } + + + /** * @throws \OCP\DB\Exception * @throws \JsonException */ public function calculateClusters(string $userId): void { - $this->logger->debug('Find face detection for use '.$userId); - $detections = $this->faceDetections->findByUserId($userId); + $this->logger->debug('ClusterDebug: Retrieving face detections for user ' . $userId); - $detections = array_values(array_filter($detections, fn ($detection) => - $detection->getHeight() > self::MIN_DETECTION_SIZE && $detection->getWidth() > self::MIN_DETECTION_SIZE - )); + $detections = $this->faceDetections->findByUserId($userId); - if (count($detections) === 0) { - $this->logger->debug('No face detections found'); + if (count($detections) < max(self::MIN_SAMPLE_SIZE, self::MIN_CLUSTER_SIZE)) { + $this->logger->debug('ClusterDebug: Not enough face detections found'); return; } - $unclusteredDetections = $this->assignToExistingClusters($userId, $detections); + $this->logger->debug('ClusterDebug: Found ' . count($detections) . " detections. Calculating clusters."); - if (count($unclusteredDetections) === 0) { - $this->logger->debug('No unclustered face detections left after incremental run'); - return; + $this->dataset = new Labeled(array_map(function (FaceDetection $detection): array { + return $detection->getVector(); + }, $detections), array_combine(array_keys($detections), array_keys($detections)), false); + + $this->detectionsTree = new MrdBallTree(10, self::MIN_SAMPLE_SIZE, new Euclidean()); + $this->detectionsTree->grow($this->dataset); + + // A quick and dirty Prim's algorithm: + //TODO: Slight performance increase could perhaps be gained by replacing arsort/array_multisort in the $this->updateNnForVertex with a function that would + // insert all new distances into the corresponding arrays perserving the descending order. + //TODO: MrdBallTree->nearestNotInGroup requires optimization (see definition of MrdBallTree) + + $this->connectedVertices = []; + $this->distanceHeap = []; //array_fill_keys(array_keys($detections), INF); + $this->distanceHeap[array_key_first($detections)] = 0; // [*distance*,] + $this->shortestDistance = 0.0; + + // Updating nearest neighbor distance for all points is unnecessary, so we + // keep track of "stale" nearest neighbor distances. These stale distances + // will only be updated if the current shortest distance in $distanceHeap exceeds + // the shortest stale distance. + $this->staleEdgeHeap = []; // Will contain tuples of detection keys and corresponding (stale) nearest neighbor distances. + $this->shortestStaleDistance = INF; + + // Key values of $edges[] will correspond to the vertex the edge connects to while the array at each row + // will be a tuple containing the vertex connected from and the connection cost/distance + $this->edges[array_key_first($detections)] = [self::ROOT_VERTEX, INF]; + //$this->edges = []; + + $numVertices = count($this->dataset->labels()) - 1; // No need to loop through the last vertex. + + while (count($this->connectedVertices) < $numVertices) { + // If necessary, update the distances in the stale heap + while ($this->shortestStaleDistance < $this->shortestDistance) { + $staleDetectionKey = array_pop($this->staleEdgeHeap)[0]; + $this->updateNnForVertex($staleDetectionKey); + } + + // Get the next edge with the smallest cost + $addedVertex = array_key_last($this->distanceHeap); // Technically it'd be equivalent to do key($distanceHeap) here... + unset($this->distanceHeap[$addedVertex]); + $this->connectedVertices[] = $addedVertex; + + $this->staleEdgeHeap[] = $this->edges[$addedVertex]; + $this->updateNnForVertex($addedVertex); } - // Here we use RubixMLs DBSCAN clustering algorithm - $dataset = new Unlabeled(array_map(function (FaceDetection $detection) : array { - return $detection->getVector(); - }, $unclusteredDetections)); + unset($this->edges[array_key_first($detections)]); - $clusterer = new DBSCAN(self::MAX_INNER_CLUSTER_RADIUS, self::MIN_CLUSTER_DENSITY, new BallTree(100, new Euclidean())); - $this->logger->debug('Calculate clusters for '.count($unclusteredDetections).' faces'); - $results = $clusterer->predict($dataset); - $numClusters = max($results); + $mstClusterer = new MstFaceCluster($this->edges, self::MIN_CLUSTER_SIZE, null, self::MAX_CLUSTER_EDGE_LENGHT, self::MIN_CLUSTER_SEPARATION); + $flatClusters = $mstClusterer->processCluster(); - $this->logger->debug('Found '.$numClusters.' new face clusters'); + // Write clusters to db + // TODO: For now just discard all previous clusters. + if (count($this->faceClusters->findByUserId($userId)) > 0) { + $this->faceClusters->deleteAll(); + } - for ($i = 0; $i <= $numClusters; $i++) { - $keys = array_keys($results, $i); - $clusterDetections = array_map(function ($key) use ($unclusteredDetections) : FaceDetection { - return $unclusteredDetections[$key]; - }, $keys); + $numberOfClusteredDetections = 0; + foreach ($flatClusters as $flatCluster) { $cluster = new FaceCluster(); $cluster->setTitle(''); $cluster->setUserId($userId); $this->faceClusters->insert($cluster); - $clusterCentroid = self::calculateCentroidOfDetections($clusterDetections); + $detectionKeys = $flatCluster->getVertexKeys(); - foreach ($clusterDetections as $detection) { - // If threshold is larger than 0 and $clusterCentroid is not the null vector - if ($detection->getThreshold() > 0.0 && count(array_filter($clusterCentroid, fn ($el) => $el !== 0.0)) > 0) { - // If a threshold is set for this detection and its vector is farther away from the centroid - // than the threshold, skip assigning this detection to the cluster - $distanceValue = self::distance($clusterCentroid, $detection->getVector()); - if ($distanceValue >= $detection->getThreshold()) { - continue; - } - } - - $this->faceDetections->assocWithCluster($detection, $cluster); + foreach ($detectionKeys as $detectionKey) { + $this->faceDetections->assocWithCluster($detections[$detectionKey], $cluster); + $numberOfClusteredDetections += 1; } } - $this->pruneDuplicateFilesFromClusters($userId); + $this->logger->debug('ClusterDebug: Clustering complete. Total num of clustered detections: ' . $numberOfClusteredDetections); + $this->pruneClusters($userId); } + + /** * @throws \OCP\DB\Exception */ - public function pruneDuplicateFilesFromClusters(string $userId): void { + public function pruneClusters(string $userId): void { $clusters = $this->faceClusters->findByUserId($userId); if (count($clusters) === 0) { @@ -124,7 +595,8 @@ public function pruneDuplicateFilesFromClusters(string $userId): void { foreach ($filesWithDuplicateFaces as $fileDetections) { $detectionsByDistance = []; foreach ($fileDetections as $detection) { - $detectionsByDistance[$detection->getId()] = self::distance($centroid, $detection->getVector()); + $distance = new Euclidean(); + $detectionsByDistance[$detection->getId()] = $distance->compute($centroid, $detection->getVector()); } asort($detectionsByDistance); $bestMatchingDetectionId = array_keys($detectionsByDistance)[0]; @@ -191,67 +663,37 @@ private function findFilesWithDuplicateFaces(array $detections): array { } /** - * @param string $userId - * @param list $detections - * @return list * @throws \OCP\DB\Exception + * @param array $clusterIdsToMerge + * @param int $parentClusterId + * @return void */ - private function assignToExistingClusters(string $userId, array $detections): array { - $clusters = $this->faceClusters->findByUserId($userId); + private function mergeClusters(array $clusterIdsToMerge, int $parentClusterId): void { + $clusterIdsToMerge = array_unique($clusterIdsToMerge); + foreach ($clusterIdsToMerge as $childClusterId) { + if ($childClusterId == $parentClusterId) { + continue; + } - if (count($clusters) === 0) { - return $detections; - } + $detections = $this->faceDetections->findByClusterId($childClusterId); + $parentCluster = $this->faceClusters->find($parentClusterId); - $unclusteredDetections = []; - foreach ($detections as $detection) { - $bestCluster = null; - $bestClusterDistance = 999; - if ($detection->getClusterId() !== null) { + try { + $childCluster = $this->faceClusters->find($childClusterId); + } catch (\Exception $e) { + $this->logger->debug('ExtraDebug: Child cluster already deleted: ' . $childClusterId); continue; } - foreach ($clusters as $cluster) { - $clusterDetections = $this->faceDetections->findByClusterId($cluster->getId()); - if (count($clusterDetections) > 50) { - $clusterDetections = array_map(fn ($key) => $clusterDetections[$key], array_rand($clusterDetections, 50)); - } - $clusterCentroid = self::calculateCentroidOfDetections($clusterDetections); - if ($detection->getThreshold() > 0 && self::distance($clusterCentroid, $detection->getVector()) >= $detection->getThreshold()) { - continue; - } - foreach ($clusterDetections as $clusterDetection) { - $distance = self::distance($clusterDetection->getVector(), $detection->getVector()); - if ( - $distance <= self::MAX_INNER_CLUSTER_RADIUS - && (!isset($bestCluster) || $distance < $bestClusterDistance) - ) { - $bestCluster = $cluster; - $bestClusterDistance = self::distance($clusterDetection->getVector(), $detection->getVector()); - break; - } - } - } - if ($bestCluster !== null) { - $this->faceDetections->assocWithCluster($detection, $bestCluster); - continue; + + foreach ($detections as $detection) { + $this->faceDetections->assocWithCluster($detection, $parentCluster); } - $unclusteredDetections[] = $detection; + + $this->faceClusters->delete($childCluster); } - return $unclusteredDetections; - } - private static ?Euclidean $distance; - /** - * @param list $v1 - * @param list $v2 - * @return float - */ - private static function distance(array $v1, array $v2): float { - if (!isset(self::$distance)) { - self::$distance = new Euclidean(); - } - return self::$distance->compute($v1, $v2); + return; } } From 2f5c07fc9fc35bdd9a3ee30ed41f0a1064e45e3f Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Mon, 9 Jan 2023 13:28:00 +0100 Subject: [PATCH 02/17] Improve HDBSCAN implementation + Use incremental processing after clustering Signed-off-by: Marcel Klehr --- lib/Clustering/MRDistance.php | 46 ++ lib/Clustering/MstClusterer.php | 215 ++++++++++ lib/Service/FaceClusterAnalyzer.php | 630 ++++++---------------------- 3 files changed, 378 insertions(+), 513 deletions(-) create mode 100644 lib/Clustering/MRDistance.php create mode 100644 lib/Clustering/MstClusterer.php diff --git a/lib/Clustering/MRDistance.php b/lib/Clustering/MRDistance.php new file mode 100644 index 00000000..17338334 --- /dev/null +++ b/lib/Clustering/MRDistance.php @@ -0,0 +1,46 @@ +coreDistSampleSize = $coreDistSampleSize; + $this->kernel = $kernel; + $this->coreDistances = []; + $this->dataset = $dataset; + + $this->distanceTree = new BallTree($coreDistSampleSize * 3, $kernel); + $this->distanceTree->grow($dataset); + + $this->kernel = $kernel; + } + + public function distance(int $a, array $aVector, int $b, array $bVector): float { + $distance = $this->kernel->compute($aVector, $bVector); + + return max($distance, $this->getCoreDistance($a), $this->getCoreDistance($b)); + } + + private function getCoreDistance(int $index): float { + if (!isset($this->coreDistances[$index])) { + [$_1, $_2, $distances] = $this->distanceTree->nearest($this->dataset->sample($index), $this->coreDistSampleSize); + $this->coreDistances[$index] = end($distances); + } + + return $this->coreDistances[$index]; + } +} diff --git a/lib/Clustering/MstClusterer.php b/lib/Clustering/MstClusterer.php new file mode 100644 index 00000000..a352e821 --- /dev/null +++ b/lib/Clustering/MstClusterer.php @@ -0,0 +1,215 @@ +edges = $edges; + + uasort($this->edges, function ($a, $b) { + if ($a[1] > $b[1]) { + return 1; + } + if ($a[1] < $b[1]) { + return -1; + } + return 0; + }); + + $this->remainingEdges = $this->edges; + + if (is_null($startingLambda)) { + $this->isRoot = true; + $this->startingLambda = 0.0; + } else { + $this->isRoot = false; + $this->startingLambda = $startingLambda; + } + + $this->minimumClusterSize = $minimumClusterSize; + + $this->coreEdges = []; + + $this->clusterWeight = 0.0; + + $this->maxEdgeLength = $maxEdgeLength; + $this->minClusterSeparation = $minClusterSeparation; + } + + public function processCluster(): array { + $currentLambda = $lastLambda = $this->startingLambda; + $edgeLength = INF; + + while (true) { + $edgeCount = count($this->remainingEdges); + + if ($edgeCount < ($this->minimumClusterSize - 1)) { + if ($edgeLength > $this->maxEdgeLength) { + return []; + } + + $this->finalLambda = $currentLambda; + $this->coreEdges = $this->remainingEdges; + + return [$this]; + } + + $vertexConnectedTo = array_key_last($this->remainingEdges); + $currentLongestEdge = array_pop($this->remainingEdges); + $vertexConnectedFrom = $currentLongestEdge[0]; + + $edgeLength = $currentLongestEdge[1]; + + if ($edgeLength > $this->maxEdgeLength) { + // Prevent formation of clusters with edges longer than the maximum edge length + $currentLambda = $lastLambda = 1 / $edgeLength; + } elseif ($edgeLength > 0.0) { + $currentLambda = 1 / $edgeLength; + } + + $this->clusterWeight += ($currentLambda - $lastLambda) * $edgeCount; + $lastLambda = $currentLambda; + + if (!$this->pruneFromCluster($vertexConnectedTo) && !$this->pruneFromCluster($vertexConnectedFrom)) { + // This cluster will (probably) split into two child clusters: + + $childClusterEdges1 = $this->getChildClusterEdges($vertexConnectedTo); + $childClusterEdges2 = $this->getChildClusterEdges($vertexConnectedFrom); + + if ($edgeLength < $this->minClusterSeparation) { + $this->remainingEdges = count($childClusterEdges1) > count($childClusterEdges2) ? $childClusterEdges1 : $childClusterEdges2; + continue; + } + + // Choose clusters using excess of mass method: + // Return a list of children if the weight of all children is more than $this->clusterWeight. + // Otherwise return the current cluster and discard the children. This way we "choose" a combination + // of clusters that weigh the most (i.e. have most (excess of) mass). Always discard the root cluster. + $this->finalLambda = $currentLambda; + + $childCluster1 = new MstClusterer($childClusterEdges1, $this->minimumClusterSize, $this->finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); + $childCluster2 = new MstClusterer($childClusterEdges2, $this->minimumClusterSize, $this->finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); + + // Resolve all chosen child clusters recursively + $childClusters = array_merge($childCluster1->processCluster(), $childCluster2->processCluster()); + + $childrenWeight = 0.0; + foreach ($childClusters as $childCluster) { + $childrenWeight += $childCluster->getClusterWeight(); + array_merge($this->coreEdges, $childCluster->getCoreEdges()); + } + + if (($childrenWeight > $this->clusterWeight) || $this->isRoot) { + return $childClusters; + } + + return [$this]; + } + + if ($edgeLength > $this->maxEdgeLength) { + $this->edges = $this->remainingEdges; + } + } + } + + private function pruneFromCluster(int $vertexId): bool { + $edgeIndicesToPrune = []; + $vertexStack = [$vertexId]; + + while (!empty($vertexStack)) { + $currentVertex = array_pop($vertexStack); + + if (count($edgeIndicesToPrune) >= ($this->minimumClusterSize - 1)) { + return false; + } + + // Traverse the MST edges backward + if (isset($this->remainingEdges[$currentVertex]) && !in_array($currentVertex, $edgeIndicesToPrune)) { + $incomingEdge = $this->remainingEdges[$currentVertex]; + $edgeIndicesToPrune[] = $currentVertex; + + $vertexStack[] = $incomingEdge[0]; + } + + // Traverse the MST edges forward + foreach ($this->remainingEdges as $key => $edge) { + if (($edge[0] == $currentVertex) && !in_array($key, $edgeIndicesToPrune)) { + $vertexStack[] = $key; + $edgeIndicesToPrune[] = $key; + } + } + } + + // Prune edges + foreach ($edgeIndicesToPrune as $edgeToPrune) { + unset($this->remainingEdges[$edgeToPrune]); + } + + return true; + } + + private function getChildClusterEdges(int $vertexId): array { + $vertexStack = [$vertexId]; + $edgesInCluster = []; + + while (!empty($vertexStack)) { + $currentVertex = array_pop($vertexStack); + + // Traverse the MST edges backward + if (isset($this->remainingEdges[$currentVertex]) && !isset($edgesInCluster[$currentVertex])) { + $incomingEdge = $this->remainingEdges[$currentVertex]; + + //Edges are indexed by the vertex they're connected to + $edgesInCluster[$currentVertex] = $incomingEdge; + + $vertexStack[] = $incomingEdge[0]; + } + + // Traverse the MST edges forward + foreach ($this->remainingEdges as $key => $edge) { + if ($edge[0] == $currentVertex && !isset($edgesInCluster[$key])) { + $vertexStack[] = $key; + $edgesInCluster[$key] = $edge; + } + } + } + + return $edgesInCluster; + } + + public function getClusterWeight(): float { + return $this->clusterWeight; + } + + public function getVertexKeys(): array { + $vertexKeys = []; + + foreach ($this->edges as $key => $edge) { + $vertexKeys[] = $key; + $vertexKeys[] = $edge[0]; + } + + return array_unique($vertexKeys); + } + + public function getCoreEdges(): array { + return $this->coreEdges; + } +} diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index dc8cb8f0..34b1ed90 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -6,431 +6,35 @@ namespace OCA\Recognize\Service; +use OCA\Recognize\Clustering\MRDistance; +use OCA\Recognize\Clustering\MstClusterer; use OCA\Recognize\Db\FaceCluster; use OCA\Recognize\Db\FaceClusterMapper; use OCA\Recognize\Db\FaceDetection; use OCA\Recognize\Db\FaceDetectionMapper; use Rubix\ML\Datasets\Labeled; -use Rubix\ML\Graph\Trees\BallTree; use Rubix\ML\Kernels\Distance\Euclidean; -use Rubix\ML\Kernels\Distance\Distance; -use Rubix\ML\Graph\Nodes\Hypersphere; -use Rubix\ML\Graph\Nodes\Clique; -use Rubix\ML\Graph\Nodes\Ball; -use Rubix\ML\Exceptions\InvalidArgumentException; -use SplObjectStorage; - -class MrdBallTree extends BallTree { - private array $coreDistances; - private int $coreDistSampleSize; - private Labeled $dataset; - - /** - * @param int $maxLeafSize - * @param int $coreDistSampleSize - * @param \Rubix\ML\Kernels\Distance\Distance|null $kernel - * @throws \Rubix\ML\Exceptions\InvalidArgumentException - */ - public function __construct(int $maxLeafSize = 30, int $coreDistSampleSize = 30, ? Distance $kernel = null) { - if ($maxLeafSize < 1) { - throw new InvalidArgumentException('At least one sample is required' - . " to form a leaf node, $maxLeafSize given."); - } - - if ($coreDistSampleSize < 2) { - throw new InvalidArgumentException('At least two samples are required' - . " to calculate core distance, $coreDistSampleSize given."); - } - - $this->maxLeafSize = $maxLeafSize; - $this->coreDistSampleSize = $coreDistSampleSize; - $this->kernel = $kernel ?? new Euclidean(); - } - - /** - * Run a k nearest neighbors search excluding the provided group of samples and return the labels, and distances for the nn in a tuple. - * - * This is essentially a stripped down version of nearest() in the parent class provided by Rubix. - * - * TODO: Implement last minimum distance filter for the hyperspheres to accelerate the search. - * Hyperspheres with centroid distance+radius smaller than the previous nearest neighbor cannot contain - * the next nearest neighbor. - * - * @internal - * - * @param int $sampleKey - * @param list $sample - * @param list $groupLabels - * @param int $k - * @throws \Rubix\ML\Exceptions\InvalidArgumentException - * @return array{list,list} - */ - public function nearestNotInGroup(int $sampleKey, array $sample, array $groupLabels, int $k = 1): array { - $visited = new SplObjectStorage(); - - $stack = $this->path($sample); - - /*$samples =*/$labels = $distances = []; - - while ($current = array_pop($stack)) { - if ($current instanceof Ball) { - $radius = $distances[$k - 1] ?? INF; - - foreach ($current->children() as $child) { - if (!$visited->contains($child)) { - if ($child instanceof Hypersphere) { - $distance = $this->kernel->compute($sample, $child->center()); - - if ($distance - $child->radius() < $radius) { - $stack[] = $child; - - continue; - } - } - - $visited->attach($child); - } - } - - $visited->attach($current); - - continue; - } - - if ($current instanceof Clique) { - $dataset = $current->dataset(); - $neighborLabels = $dataset->labels(); - - foreach ($dataset->samples() as $key => $neighbor) { - if (!in_array($neighborLabels[$key], $groupLabels)) { - $distances[] = $this->computeMrd($sampleKey, $sample, $neighborLabels[$key], $neighbor); - $labels[] = $neighborLabels[$key]; - } - } - - //$samples = array_merge($samples, $dataset->samples()); - //$labels = array_merge($labels, $dataset->labels()); - - array_multisort($distances, /*$samples,*/$labels); - - if (count($labels) > $k) { - //$samples = array_slice($samples, 0, $k); - $labels = array_slice($labels, 0, $k); - $distances = array_slice($distances, 0, $k); - } - - $visited->attach($current); - } - } - - return [ /*$samples,*/$labels, $distances]; - } - - private function getCoreDistance(int $index): float { - if (!isset($this->coreDistances[$index])) { - [$_1, $_2, $distances] = $this->nearest($this->dataset->sample($index), $this->coreDistSampleSize); - $this->coreDistances[$index] = max($distances); - } - - return $this->coreDistances[$index]; - } - - /** - * Compute the mutual reachability distance between two vectors. - * - * @internal - * - * @param int $a - * @param array $a_vector - * @param int $b - * @param array $b_vector - * @return float - */ - private function computeMrd(int $a, array $a_vector, int $b, array $b_vector): float { - $distance = $this->kernel->compute($a_vector, $b_vector); - - return max($distance, $this->getCoreDistance($a), $this->getCoreDistance($b)); - } - - /** - * Insert a root node and recursively split the dataset until a terminating - * condition is met. - * - * @internal - * - * @param \Rubix\ML\Datasets\Labeled $dataset - * @throws \Rubix\ML\Exceptions\InvalidArgumentException - */ - public function grow(Labeled $dataset): void { - $this->dataset = $dataset; - $this->root = Ball::split($dataset, $this->kernel); - - $stack = [$this->root]; - - while ($current = array_pop($stack)) { - [$left, $right] = $current->subsets(); - - $current->cleanup(); - - if ($left->numSamples() > $this->maxLeafSize) { - $node = Ball::split($left, $this->kernel); - - $current->attachLeft($node); - - $stack[] = $node; - } elseif (!$left->empty()) { - $current->attachLeft(Clique::terminate($left, $this->kernel)); - } - - if ($right->numSamples() > $this->maxLeafSize) { - $node = Ball::split($right, $this->kernel); - - if ($node->isPoint()) { - $current->attachRight(Clique::terminate($right, $this->kernel)); - } else { - $current->attachRight($node); - - $stack[] = $node; - } - } elseif (!$right->empty()) { - $current->attachRight(Clique::terminate($right, $this->kernel)); - } - } - } -} - -class MstFaceCluster { - private array $edges; - private array $remainingEdges; - private float $startingLambda; - private float $finalLambda; - private float $clusterWeight; - private int $minimumClusterSize; - private array $coreEdges; - private bool $isRoot; - private float $maxEdgeLength; - private float $minClusterSeparation; - - - - - public function __construct(array $edges, int $minimumClusterSize, ?float $startingLambda = null, float $maxEdgeLength = 0.5, float $minClusterSeparation = 0.1) { - //Ascending sort of edges while perserving original keys. - $this->edges = $edges; - - uasort($this->edges, function ($a, $b) { - if ($a[1] > $b[1]) { - return 1; - } - if ($a[1] < $b[1]) { - return -1; - } - return 0; - }); - - $this->remainingEdges = $this->edges; - - if (is_null($startingLambda)) { - $this->isRoot = true; - $this->startingLambda = 0.0; - } else { - $this->isRoot = false; - $this->startingLambda = $startingLambda; - } - - $this->minimumClusterSize = $minimumClusterSize; - - $this->coreEdges = []; - - $this->clusterWeight = 0.0; - - $this->maxEdgeLength = $maxEdgeLength; - $this->minClusterSeparation = $minClusterSeparation; - } - - public function processCluster(): array { - $currentLambda = $lastLambda = $this->startingLambda; - $edgeLength = INF; - - while (true) { - $edgeCount = count($this->remainingEdges); - - if ($edgeCount < ($this->minimumClusterSize - 1)) { - if ($edgeLength > $this->maxEdgeLength) { - return []; - } - - $this->finalLambda = $currentLambda; - $this->coreEdges = $this->remainingEdges; - - return [$this]; - } - - $vertexConnectedTo = array_key_last($this->remainingEdges); - $currentLongestEdge = array_pop($this->remainingEdges); - $vertexConnectedFrom = $currentLongestEdge[0]; - - $edgeLength = $currentLongestEdge[1]; - - if ($edgeLength > $this->maxEdgeLength) { - // Prevent formation of clusters with edges longer than the maximum edge length - $currentLambda = $lastLambda = 1 / $edgeLength; - } elseif ($edgeLength > 0.0) { - $currentLambda = 1 / $edgeLength; - } - - $this->clusterWeight += ($currentLambda - $lastLambda) * $edgeCount; - $lastLambda = $currentLambda; - - if (!$this->pruneFromCluster($vertexConnectedTo) && !$this->pruneFromCluster($vertexConnectedFrom)) { - // This cluster will (probably) split into two child clusters: - - $childClusterEdges1 = $this->getChildClusterEdges($vertexConnectedTo); - $childClusterEdges2 = $this->getChildClusterEdges($vertexConnectedFrom); - - if ($edgeLength < $this->minClusterSeparation) { - $this->remainingEdges = count($childClusterEdges1) > count($childClusterEdges2) ? $childClusterEdges1 : $childClusterEdges2; - continue; - } - - // Choose clusters using excess of mass method: - // Return a list of children if the weight of all children is more than $this->clusterWeight. - // Otherwise return the current cluster and discard the children. This way we "choose" a combination - // of cluster that has weighs the most (i.e. has most excess of mass). Always discard the root cluster. - $this->finalLambda = $currentLambda; - - $childCluster1 = new MstFaceCluster($childClusterEdges1, $this->minimumClusterSize, $this->finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); - $childCluster2 = new MstFaceCluster($childClusterEdges2, $this->minimumClusterSize, $this->finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); - - // Resolve all chosen child clusters recursively - $childClusters = array_merge($childCluster1->processCluster(), $childCluster2->processCluster()); - - $childrenWeight = 0.0; - foreach ($childClusters as $childCluster) { - $childrenWeight += $childCluster->getClusterWeight(); - array_merge($this->coreEdges, $childCluster->getCoreEdges()); - } - - if (($childrenWeight > $this->clusterWeight) || $this->isRoot) { - return $childClusters; - } - - return [$this]; - } - - if ($edgeLength > $this->maxEdgeLength) { - $this->edges = $this->remainingEdges; - } - } - } - - private function pruneFromCluster(int $vertexId): bool { - $edgeIndicesToPrune = []; - $vertexStack = [$vertexId]; - - while (!empty($vertexStack)) { - $currentVertex = array_pop($vertexStack); - - if (count($edgeIndicesToPrune) >= ($this->minimumClusterSize - 1)) { - return false; - } - - // Traverse the MST edges backward - if (isset($this->remainingEdges[$currentVertex]) && !in_array($currentVertex, $edgeIndicesToPrune)) { - $incomingEdge = $this->remainingEdges[$currentVertex]; - $edgeIndicesToPrune[] = $currentVertex; - - $vertexStack[] = $incomingEdge[0]; - } - - // Traverse the MST edges forward - foreach ($this->remainingEdges as $key => $edge) { - if (($edge[0] == $currentVertex) && !in_array($key, $edgeIndicesToPrune)) { - $vertexStack[] = $key; - $edgeIndicesToPrune[] = $key; - } - } - } - - // Prune edges - foreach ($edgeIndicesToPrune as $edgeToPrune) { - unset($this->remainingEdges[$edgeToPrune]); - } - - return true; - } - - private function getChildClusterEdges(int $vertexId): array { - $vertexStack = [$vertexId]; - $edgesInCluster = []; - - while (!empty($vertexStack)) { - $currentVertex = array_pop($vertexStack); - - // Traverse the MST edges backward - if (isset($this->remainingEdges[$currentVertex]) && !isset($edgesInCluster[$currentVertex])) { - $incomingEdge = $this->remainingEdges[$currentVertex]; - - //Edges are indexed by the vertex they're connected to - $edgesInCluster[$currentVertex] = $incomingEdge; - - $vertexStack[] = $incomingEdge[0]; - } - - // Traverse the MST edges forward - foreach ($this->remainingEdges as $key => $edge) { - if ($edge[0] == $currentVertex && !isset($edgesInCluster[$key])) { - $vertexStack[] = $key; - $edgesInCluster[$key] = $edge; - } - } - } - - return $edgesInCluster; - } - - public function getClusterWeight(): float { - return $this->clusterWeight; - } - - public function getVertexKeys(): array { - $vertexKeys = []; - - foreach ($this->edges as $key => $edge) { - $vertexKeys[] = $key; - $vertexKeys[] = $edge[0]; - } - - return array_unique($vertexKeys); - } - - public function getCoreEdges(): array { - return $this->coreEdges; - } -} - class FaceClusterAnalyzer { public const MIN_SAMPLE_SIZE = 4; // Conservative value: 10 public const MIN_CLUSTER_SIZE = 6; // Conservative value: 10 public const MAX_CLUSTER_EDGE_LENGHT = 99.0; public const MIN_CLUSTER_SEPARATION = 0.0; + // For incremental clustering + public const MAX_INNER_CLUSTER_RADIUS = 0.44; + public const MIN_DETECTION_SIZE = 0.03; + public const DIMENSIONS = 128; - public const ROOT_VERTEX = null; private FaceDetectionMapper $faceDetections; private FaceClusterMapper $faceClusters; private TagManager $tagManager; private Logger $logger; - private array $connectedVertices; - private array $distanceHeap; - private array $staleEdgeHeap; private array $edges; - private float $shortestStaleDistance; - private float $shortestDistance; private Labeled $dataset; - private MrdBallTree $detectionsTree; + + private MRDistance $distanceKernel; public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapper $faceClusters, TagManager $tagManager, Logger $logger) { $this->faceDetections = $faceDetections; @@ -439,40 +43,6 @@ public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapp $this->logger = $logger; } - private function updateNnForVertex($vertexId): void { - if (is_null($vertexId)) { - return; - } - - [$nearestLabel, $nearestDistance] = $this->detectionsTree->nearestNotInGroup($vertexId, $this->dataset->sample($vertexId), $this->connectedVertices); - - // Two possibilities here: First, it's possibe that the distance to the nearest neighbor is less than - // the previous distance established for this vertex. Then we can remove this - // stale key from the $staleEdgeHeap. The second possibility is that the new distance is not the shortest - // available to the nearest neighbor in which case we just push the current $staleDetectionKey - // back into stale heap with the new longer distance. - if (($this->distanceHeap[$nearestLabel[0]] ?? INF) > $nearestDistance[0]) { - $this->distanceHeap[$nearestLabel[0]] = $nearestDistance[0]; - // If the nearest neighbor vertex already had an edge connected to it - // (with a longer distance) set the existing edge as stale. - if (isset($this->edges[$nearestLabel[0]])) { - $this->staleEdgeHeap[] = $this->edges[$nearestLabel[0]]; - } - - $this->edges[$nearestLabel[0]] = [$vertexId, $nearestDistance[0]]; - arsort($this->distanceHeap); - $this->shortestDistance = end($this->distanceHeap); - } else { - $this->staleEdgeHeap[] = [$vertexId, $nearestDistance[0]]; - } - - $distanceColumn = array_column($this->staleEdgeHeap, 1); - array_multisort($distanceColumn, SORT_DESC, $this->staleEdgeHeap); - $this->shortestStaleDistance = end($this->staleEdgeHeap)[1]; - - return; - } - /** @@ -484,71 +54,71 @@ public function calculateClusters(string $userId): void { $detections = $this->faceDetections->findByUserId($userId); - if (count($detections) < max(self::MIN_SAMPLE_SIZE, self::MIN_CLUSTER_SIZE)) { + $detections = array_values(array_filter($detections, fn ($detection) => + $detection->getHeight() > self::MIN_DETECTION_SIZE && $detection->getWidth() > self::MIN_DETECTION_SIZE + )); + + $unclusteredDetections = $this->assignToExistingClusters($userId, $detections); + + if (count($unclusteredDetections) < max(self::MIN_SAMPLE_SIZE, self::MIN_CLUSTER_SIZE)) { $this->logger->debug('ClusterDebug: Not enough face detections found'); return; } - $this->logger->debug('ClusterDebug: Found ' . count($detections) . " detections. Calculating clusters."); + $this->logger->debug('ClusterDebug: Found ' . count($unclusteredDetections) . " unclustered detections. Calculating clusters."); $this->dataset = new Labeled(array_map(function (FaceDetection $detection): array { return $detection->getVector(); }, $detections), array_combine(array_keys($detections), array_keys($detections)), false); - $this->detectionsTree = new MrdBallTree(10, self::MIN_SAMPLE_SIZE, new Euclidean()); - $this->detectionsTree->grow($this->dataset); - - // A quick and dirty Prim's algorithm: - //TODO: Slight performance increase could perhaps be gained by replacing arsort/array_multisort in the $this->updateNnForVertex with a function that would - // insert all new distances into the corresponding arrays perserving the descending order. - //TODO: MrdBallTree->nearestNotInGroup requires optimization (see definition of MrdBallTree) - - $this->connectedVertices = []; - $this->distanceHeap = []; //array_fill_keys(array_keys($detections), INF); - $this->distanceHeap[array_key_first($detections)] = 0; // [*distance*,] - $this->shortestDistance = 0.0; - - // Updating nearest neighbor distance for all points is unnecessary, so we - // keep track of "stale" nearest neighbor distances. These stale distances - // will only be updated if the current shortest distance in $distanceHeap exceeds - // the shortest stale distance. - $this->staleEdgeHeap = []; // Will contain tuples of detection keys and corresponding (stale) nearest neighbor distances. - $this->shortestStaleDistance = INF; - - // Key values of $edges[] will correspond to the vertex the edge connects to while the array at each row - // will be a tuple containing the vertex connected from and the connection cost/distance - $this->edges[array_key_first($detections)] = [self::ROOT_VERTEX, INF]; - //$this->edges = []; - - $numVertices = count($this->dataset->labels()) - 1; // No need to loop through the last vertex. - - while (count($this->connectedVertices) < $numVertices) { - // If necessary, update the distances in the stale heap - while ($this->shortestStaleDistance < $this->shortestDistance) { - $staleDetectionKey = array_pop($this->staleEdgeHeap)[0]; - $this->updateNnForVertex($staleDetectionKey); - } + $this->distanceKernel = new MRDistance(self::MIN_SAMPLE_SIZE, $this->dataset, new Euclidean()); + + $primsStartTime = microtime(true);// DEBUG - // Get the next edge with the smallest cost - $addedVertex = array_key_last($this->distanceHeap); // Technically it'd be equivalent to do key($distanceHeap) here... - unset($this->distanceHeap[$addedVertex]); - $this->connectedVertices[] = $addedVertex; + // Prim's algorithm: - $this->staleEdgeHeap[] = $this->edges[$addedVertex]; - $this->updateNnForVertex($addedVertex); + $this->unconnectedVertices = array_combine(array_keys($detections), array_keys($detections)); + + $firstVertex = current($this->unconnectedVertices); + $firstVertexVector = $this->dataset->sample($firstVertex); + unset($this->unconnectedVertices[$firstVertex]); + + $this->edges = []; + foreach ($this->unconnectedVertices as $vertex) { + $this->edges[$vertex] = [$firstVertex, $this->distanceKernel->distance($firstVertex, $firstVertexVector, $vertex, $this->dataset->sample($vertex))]; } - unset($this->edges[array_key_first($detections)]); + while (count($this->unconnectedVertices) > 0) { + $minDistance = INF; + $minVertex = null; - $mstClusterer = new MstFaceCluster($this->edges, self::MIN_CLUSTER_SIZE, null, self::MAX_CLUSTER_EDGE_LENGHT, self::MIN_CLUSTER_SEPARATION); - $flatClusters = $mstClusterer->processCluster(); + foreach ($this->unconnectedVertices as $vertex) { + $distance = $this->edges[$vertex][1]; + if ($distance < $minDistance) { + $minDistance = $distance; + $minVertex = $vertex; + } + } - // Write clusters to db - // TODO: For now just discard all previous clusters. - if (count($this->faceClusters->findByUserId($userId)) > 0) { - $this->faceClusters->deleteAll(); + unset($this->unconnectedVertices[$minVertex]); + $minVertexVector = $this->dataset->sample($minVertex); + + foreach ($this->unconnectedVertices as $vertex) { + $distance = $this->distanceKernel->distance($minVertex, $minVertexVector, $vertex, $this->dataset->sample($vertex)); + if ($this->edges[$vertex][1] > $distance) { + $this->edges[$vertex] = [$minVertex,$distance]; + } + } } + $executionTime = (microtime(true) - $primsStartTime);// DEBUG + $this->logger->debug('ClusterDebug: Prims algo took '.$executionTime." secs.");// DEBUG + + // Calculate the face clusters based on the minimum spanning tree. + + $mstClusterer = new MstClusterer($this->edges, self::MIN_CLUSTER_SIZE, null, self::MAX_CLUSTER_EDGE_LENGHT, self::MIN_CLUSTER_SEPARATION); + $flatClusters = $mstClusterer->processCluster(); + $numberOfClusteredDetections = 0; foreach ($flatClusters as $flatCluster) { @@ -558,9 +128,21 @@ public function calculateClusters(string $userId): void { $this->faceClusters->insert($cluster); $detectionKeys = $flatCluster->getVertexKeys(); + $clusterCentroid = self::calculateCentroidOfDetections(array_map(static fn ($key) => $unclusteredDetections[$key], $detectionKeys)); + foreach ($detectionKeys as $detectionKey) { - $this->faceDetections->assocWithCluster($detections[$detectionKey], $cluster); + // If threshold is larger than 0 and $clusterCentroid is not the null vector + if ($unclusteredDetections[$detectionKey]->getThreshold() > 0.0 && count(array_filter($clusterCentroid, fn ($el) => $el !== 0.0)) > 0) { + // If a threshold is set for this detection and its vector is farther away from the centroid + // than the threshold, skip assigning this detection to the cluster + $distanceValue = self::distance($clusterCentroid, $unclusteredDetections[$detectionKey]->getVector()); + if ($distanceValue >= $unclusteredDetections[$detectionKey]->getThreshold()) { + continue; + } + } + + $this->faceDetections->assocWithCluster($unclusteredDetections[$detectionKey], $cluster); $numberOfClusteredDetections += 1; } } @@ -569,8 +151,6 @@ public function calculateClusters(string $userId): void { $this->pruneClusters($userId); } - - /** * @throws \OCP\DB\Exception */ @@ -662,38 +242,62 @@ private function findFilesWithDuplicateFaces(array $detections): array { return $filesWithDuplicateFaces; } - /** - * @throws \OCP\DB\Exception - * @param array $clusterIdsToMerge - * @param int $parentClusterId - * @return void - */ - private function mergeClusters(array $clusterIdsToMerge, int $parentClusterId): void { - $clusterIdsToMerge = array_unique($clusterIdsToMerge); - foreach ($clusterIdsToMerge as $childClusterId) { - if ($childClusterId == $parentClusterId) { - continue; - } + private function assignToExistingClusters(string $userId, array $detections): array { + $clusters = $this->faceClusters->findByUserId($userId); - $detections = $this->faceDetections->findByClusterId($childClusterId); - $parentCluster = $this->faceClusters->find($parentClusterId); + if (count($clusters) === 0) { + return $detections; + } + $unclusteredDetections = []; - try { - $childCluster = $this->faceClusters->find($childClusterId); - } catch (\Exception $e) { - $this->logger->debug('ExtraDebug: Child cluster already deleted: ' . $childClusterId); + foreach ($detections as $detection) { + $bestCluster = null; + $bestClusterDistance = 999; + if ($detection->getClusterId() !== null) { continue; } - - foreach ($detections as $detection) { - $this->faceDetections->assocWithCluster($detection, $parentCluster); + foreach ($clusters as $cluster) { + $clusterDetections = $this->faceDetections->findByClusterId($cluster->getId()); + if (count($clusterDetections) > 50) { + $clusterDetections = array_map(fn ($key) => $clusterDetections[$key], array_rand($clusterDetections, 50)); + } + $clusterCentroid = self::calculateCentroidOfDetections($clusterDetections); + if ($detection->getThreshold() > 0 && self::distance($clusterCentroid, $detection->getVector()) >= $detection->getThreshold()) { + continue; + } + foreach ($clusterDetections as $clusterDetection) { + $distance = self::distance($clusterDetection->getVector(), $detection->getVector()); + if ( + $distance <= self::MAX_INNER_CLUSTER_RADIUS + && (!isset($bestCluster) || $distance < $bestClusterDistance) + ) { + $bestCluster = $cluster; + $bestClusterDistance = self::distance($clusterDetection->getVector(), $detection->getVector()); + break; + } + } } - - $this->faceClusters->delete($childCluster); + if ($bestCluster !== null) { + $this->faceDetections->assocWithCluster($detection, $bestCluster); + continue; + } + $unclusteredDetections[] = $detection; } + return $unclusteredDetections; + } + private static ?Euclidean $distance; - return; + /** + * @param list $v1 + * @param list $v2 + * @return float + */ + private static function distance(array $v1, array $v2): float { + if (!isset(self::$distance)) { + self::$distance = new Euclidean(); + } + return self::$distance->compute($v1, $v2); } } From 14c94da3f7b98b5902a43277b9319eaee293539d Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Mon, 9 Jan 2023 17:05:11 +0100 Subject: [PATCH 03/17] Refactor HDBSCAN Signed-off-by: Marcel Klehr --- lib/Clustering/MRDistance.php | 11 +++++ lib/Clustering/MstClusterer.php | 30 +++++++++---- lib/Service/FaceClusterAnalyzer.php | 65 ++++++++++++++--------------- 3 files changed, 65 insertions(+), 41 deletions(-) diff --git a/lib/Clustering/MRDistance.php b/lib/Clustering/MRDistance.php index 17338334..9363f76e 100644 --- a/lib/Clustering/MRDistance.php +++ b/lib/Clustering/MRDistance.php @@ -12,6 +12,10 @@ class MRDistance { private Distance $kernel; + + /** + * @var list $coreDistances + */ private array $coreDistances; private int $coreDistSampleSize; private Labeled $dataset; @@ -29,6 +33,13 @@ public function __construct(int $coreDistSampleSize, Labeled $dataset, Distance $this->kernel = $kernel; } + /** + * @param int $a + * @param list $aVector + * @param int $b + * @param list $bVector + * @return float + */ public function distance(int $a, array $aVector, int $b, array $bVector): float { $distance = $this->kernel->compute($aVector, $bVector); diff --git a/lib/Clustering/MstClusterer.php b/lib/Clustering/MstClusterer.php index a352e821..970784a1 100644 --- a/lib/Clustering/MstClusterer.php +++ b/lib/Clustering/MstClusterer.php @@ -9,21 +9,34 @@ // TODO: core edges are not always stored properly (if two halves of the remaining clusters are both pruned at the same time) // TODO: store vertex lambda length (relative to cluster lambda length) for all vertices for soft clustering. class MstClusterer { + /** + * @var array + */ private array $edges; + /** + * @var array + */ private array $remainingEdges; private float $startingLambda; - private float $finalLambda; private float $clusterWeight; private int $minimumClusterSize; private array $coreEdges; private bool $isRoot; private float $maxEdgeLength; private float $minClusterSeparation; + + /** + * @param array $edges + * @param int $minimumClusterSize + * @param float|null $startingLambda + * @param float $maxEdgeLength + * @param float $minClusterSeparation + */ public function __construct(array $edges, int $minimumClusterSize, ?float $startingLambda = null, float $maxEdgeLength = 0.5, float $minClusterSeparation = 0.1) { //Ascending sort of edges while perserving original keys. $this->edges = $edges; - uasort($this->edges, function ($a, $b) { + uasort($this->edges, static function ($a, $b) { if ($a[1] > $b[1]) { return 1; } @@ -65,7 +78,6 @@ public function processCluster(): array { return []; } - $this->finalLambda = $currentLambda; $this->coreEdges = $this->remainingEdges; return [$this]; @@ -102,10 +114,10 @@ public function processCluster(): array { // Return a list of children if the weight of all children is more than $this->clusterWeight. // Otherwise return the current cluster and discard the children. This way we "choose" a combination // of clusters that weigh the most (i.e. have most (excess of) mass). Always discard the root cluster. - $this->finalLambda = $currentLambda; + $finalLambda = $currentLambda; - $childCluster1 = new MstClusterer($childClusterEdges1, $this->minimumClusterSize, $this->finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); - $childCluster2 = new MstClusterer($childClusterEdges2, $this->minimumClusterSize, $this->finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); + $childCluster1 = new MstClusterer($childClusterEdges1, $this->minimumClusterSize, $finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); + $childCluster2 = new MstClusterer($childClusterEdges2, $this->minimumClusterSize, $finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); // Resolve all chosen child clusters recursively $childClusters = array_merge($childCluster1->processCluster(), $childCluster2->processCluster()); @@ -113,7 +125,7 @@ public function processCluster(): array { $childrenWeight = 0.0; foreach ($childClusters as $childCluster) { $childrenWeight += $childCluster->getClusterWeight(); - array_merge($this->coreEdges, $childCluster->getCoreEdges()); + $this->coreEdges = array_merge($this->coreEdges, $childCluster->getCoreEdges()); } if (($childrenWeight > $this->clusterWeight) || $this->isRoot) { @@ -198,6 +210,10 @@ public function getClusterWeight(): float { return $this->clusterWeight; } + + /** + * @returns list + */ public function getVertexKeys(): array { $vertexKeys = []; diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index 34b1ed90..1123f9e5 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -17,7 +17,7 @@ class FaceClusterAnalyzer { public const MIN_SAMPLE_SIZE = 4; // Conservative value: 10 - public const MIN_CLUSTER_SIZE = 6; // Conservative value: 10 + public const MIN_CLUSTER_SIZE = 5; // Conservative value: 10 public const MAX_CLUSTER_EDGE_LENGHT = 99.0; public const MIN_CLUSTER_SEPARATION = 0.0; // For incremental clustering @@ -28,23 +28,14 @@ class FaceClusterAnalyzer { private FaceDetectionMapper $faceDetections; private FaceClusterMapper $faceClusters; - private TagManager $tagManager; private Logger $logger; - private array $edges; - private Labeled $dataset; - - private MRDistance $distanceKernel; - - public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapper $faceClusters, TagManager $tagManager, Logger $logger) { + public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapper $faceClusters, Logger $logger) { $this->faceDetections = $faceDetections; $this->faceClusters = $faceClusters; - $this->tagManager = $tagManager; $this->logger = $logger; } - - /** * @throws \OCP\DB\Exception * @throws \JsonException @@ -67,46 +58,46 @@ public function calculateClusters(string $userId): void { $this->logger->debug('ClusterDebug: Found ' . count($unclusteredDetections) . " unclustered detections. Calculating clusters."); - $this->dataset = new Labeled(array_map(function (FaceDetection $detection): array { + $dataset = new Labeled(array_map(static function (FaceDetection $detection): array { return $detection->getVector(); - }, $detections), array_combine(array_keys($detections), array_keys($detections)), false); + }, $unclusteredDetections), array_combine(array_keys($unclusteredDetections), array_keys($unclusteredDetections)), false); - $this->distanceKernel = new MRDistance(self::MIN_SAMPLE_SIZE, $this->dataset, new Euclidean()); + $distanceKernel = new MRDistance(self::MIN_SAMPLE_SIZE, $dataset, new Euclidean()); $primsStartTime = microtime(true);// DEBUG // Prim's algorithm: - $this->unconnectedVertices = array_combine(array_keys($detections), array_keys($detections)); + $unconnectedVertices = array_combine(array_keys($detections), array_keys($detections)); - $firstVertex = current($this->unconnectedVertices); - $firstVertexVector = $this->dataset->sample($firstVertex); - unset($this->unconnectedVertices[$firstVertex]); + $firstVertex = current($unconnectedVertices); + $firstVertexVector = $dataset->sample($firstVertex); + unset($unconnectedVertices[$firstVertex]); - $this->edges = []; - foreach ($this->unconnectedVertices as $vertex) { - $this->edges[$vertex] = [$firstVertex, $this->distanceKernel->distance($firstVertex, $firstVertexVector, $vertex, $this->dataset->sample($vertex))]; + $edges = []; + foreach ($unconnectedVertices as $vertex) { + $edges[$vertex] = [$firstVertex, $distanceKernel->distance($firstVertex, $firstVertexVector, $vertex, $dataset->sample($vertex))]; } - while (count($this->unconnectedVertices) > 0) { + while (count($unconnectedVertices) > 0) { $minDistance = INF; $minVertex = null; - foreach ($this->unconnectedVertices as $vertex) { - $distance = $this->edges[$vertex][1]; + foreach ($unconnectedVertices as $vertex) { + $distance = $edges[$vertex][1]; if ($distance < $minDistance) { $minDistance = $distance; $minVertex = $vertex; } } - unset($this->unconnectedVertices[$minVertex]); - $minVertexVector = $this->dataset->sample($minVertex); + unset($unconnectedVertices[$minVertex]); + $minVertexVector = $dataset->sample($minVertex); - foreach ($this->unconnectedVertices as $vertex) { - $distance = $this->distanceKernel->distance($minVertex, $minVertexVector, $vertex, $this->dataset->sample($vertex)); - if ($this->edges[$vertex][1] > $distance) { - $this->edges[$vertex] = [$minVertex,$distance]; + foreach ($unconnectedVertices as $vertex) { + $distance = $distanceKernel->distance($minVertex, $minVertexVector, $vertex, $dataset->sample($vertex)); + if ($edges[$vertex][1] > $distance) { + $edges[$vertex] = [$minVertex,$distance]; } } } @@ -116,7 +107,7 @@ public function calculateClusters(string $userId): void { // Calculate the face clusters based on the minimum spanning tree. - $mstClusterer = new MstClusterer($this->edges, self::MIN_CLUSTER_SIZE, null, self::MAX_CLUSTER_EDGE_LENGHT, self::MIN_CLUSTER_SEPARATION); + $mstClusterer = new MstClusterer($edges, self::MIN_CLUSTER_SIZE, null, self::MAX_CLUSTER_EDGE_LENGHT, self::MIN_CLUSTER_SEPARATION); $flatClusters = $mstClusterer->processCluster(); $numberOfClusteredDetections = 0; @@ -209,12 +200,12 @@ public static function calculateCentroidOfDetections(array $detections): array { } foreach ($detections as $detection) { - $sum = array_map(function ($el, $el2) { + $sum = array_map(static function ($el, $el2) { return $el + $el2; }, $detection->getVector(), $sum); } - $centroid = array_map(function ($el) use ($detections) { + $centroid = array_map(static function ($el) use ($detections) { return $el / count($detections); }, $sum); @@ -235,13 +226,19 @@ private function findFilesWithDuplicateFaces(array $detections): array { } /** @var array $filesWithDuplicateFaces */ - $filesWithDuplicateFaces = array_filter($files, function ($detections) { + $filesWithDuplicateFaces = array_filter($files, static function ($detections) { return count($detections) > 1; }); return $filesWithDuplicateFaces; } + /** + * @param string $userId + * @param list $detections + * @return list + * @throws \OCP\DB\Exception + */ private function assignToExistingClusters(string $userId, array $detections): array { $clusters = $this->faceClusters->findByUserId($userId); From bc390c3b28ddaec6f62d74daecddd939c8cadbff Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Mon, 9 Jan 2023 17:05:52 +0100 Subject: [PATCH 04/17] Update psalm baseline Signed-off-by: Marcel Klehr --- psalm-baseline.xml | 115 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 110 insertions(+), 5 deletions(-) diff --git a/psalm-baseline.xml b/psalm-baseline.xml index 5a29181c..96fcd344 100644 --- a/psalm-baseline.xml +++ b/psalm-baseline.xml @@ -182,6 +182,79 @@ getFileId + + + compute + grow + nearest + + + $this->coreDistances[$index] + + + float + + + $this->dataset->sample($index) + + + $this->coreDistances + + + + + $childCluster->getCoreEdges() + $finalLambda + $vertexConnectedFrom + + + $childClusterEdges1 + $childClusterEdges2 + $vertexConnectedTo + $vertexConnectedTo + + + $a[1] + $a[1] + $b[1] + $b[1] + $currentLongestEdge[0] + $currentLongestEdge[1] + + + $childCluster + $childrenWeight + $currentLambda + $currentLambda + $currentLongestEdge + $edgeLength + $finalLambda + $lastLambda + $lastLambda + $this->clusterWeight + $vertexConnectedFrom + + + getClusterWeight + getCoreEdges + + + $childCluster->getClusterWeight() + $currentLambda - $lastLambda + $edgeLength + $edgeLength + $lastLambda + $this->clusterWeight + + + $this->remainingEdges + count($childClusterEdges1) > count($childClusterEdges2) ? $childClusterEdges1 : $childClusterEdges2 + + + $vertexConnectedTo + $vertexConnectedTo + + $value @@ -423,7 +496,9 @@ $item->getPathname() $item->getPathname() - + + $output + $output $output $output $output @@ -468,12 +543,42 @@ - - $results - - + + compute compute + + $key + + + $detectionKeys + $firstVertex + + + $dataset->sample($vertex) + $dataset->sample($vertex) + $firstVertexVector + $minVertexVector + + + $unclusteredDetections[$detectionKey] + $unclusteredDetections[$key] + + + $detectionKey + $detectionKeys + $firstVertex + $flatCluster + + + getVertexKeys + + + $minVertex + + + $unconnectedVertices + $detection->getClusterId() !== null From 0e3085b09c54ac3697fc1ead8f8d4a9125cb6884 Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Mon, 9 Jan 2023 20:16:32 +0100 Subject: [PATCH 05/17] Fix typo Signed-off-by: Marcel Klehr --- lib/Service/FaceClusterAnalyzer.php | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index 1123f9e5..13054760 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -68,7 +68,7 @@ public function calculateClusters(string $userId): void { // Prim's algorithm: - $unconnectedVertices = array_combine(array_keys($detections), array_keys($detections)); + $unconnectedVertices = array_combine(array_keys($unclusteredDetections), array_keys($unclusteredDetections)); $firstVertex = current($unconnectedVertices); $firstVertexVector = $dataset->sample($firstVertex); From 122b25027503bd034ba14e7c7ad762e01a6df71a Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Mon, 9 Jan 2023 20:16:45 +0100 Subject: [PATCH 06/17] InstallDeps: Add more types Signed-off-by: Marcel Klehr --- lib/Migration/InstallDeps.php | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Migration/InstallDeps.php b/lib/Migration/InstallDeps.php index dbb5558d..e84565dd 100644 --- a/lib/Migration/InstallDeps.php +++ b/lib/Migration/InstallDeps.php @@ -92,7 +92,7 @@ public function run(IOutput $output): void { $this->runFfmpegInstall($binaryPath); } - protected function installNodeBinary($output) : void { + protected function installNodeBinary(IOutput $output) : void { $isARM = false; $isMusl = false; $uname = php_uname('m'); From ca06556484ed2d9737480eeabfe9116c54447d6c Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Mon, 9 Jan 2023 20:23:57 +0100 Subject: [PATCH 07/17] Fix tests Signed-off-by: Marcel Klehr --- test/ClusterTest.php | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ClusterTest.php b/test/ClusterTest.php index 9d4565fc..1301a5ad 100644 --- a/test/ClusterTest.php +++ b/test/ClusterTest.php @@ -326,13 +326,13 @@ public function testClusterTemptClusterMerging() { self::assertCount(3, $clusters); $detections = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 2, $detections); + self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER, $detections); $detections = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER, $detections); $detections = $this->faceDetectionMapper->findByClusterId($clusters[2]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER, $detections); + self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 2, $detections); } private static function getNullVector() { From 171cf14acbaeb9db8785d1d14aea708138264fde Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Wed, 11 Jan 2023 15:12:26 +0100 Subject: [PATCH 08/17] Stabilize ClusterTemptClusterMerging test Signed-off-by: Marcel Klehr --- test/ClusterTest.php | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/test/ClusterTest.php b/test/ClusterTest.php index 1301a5ad..a097a433 100644 --- a/test/ClusterTest.php +++ b/test/ClusterTest.php @@ -325,14 +325,13 @@ public function testClusterTemptClusterMerging() { $clusters = $this->faceClusterMapper->findByUserId(self::TEST_USER1); self::assertCount(3, $clusters); - $detections = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER, $detections); + $detections1 = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); + $detections2 = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); + $detections3 = $this->faceDetectionMapper->findByClusterId($clusters[2]->getId()); + $counts = [count($detections1), count($detections2), count($detections3)]; - $detections = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER, $detections); - - $detections = $this->faceDetectionMapper->findByClusterId($clusters[2]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 2, $detections); + self::assertCount(2, array_filter($counts, fn($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER)); + self::assertCount(1, array_filter($counts, fn($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER * 2)); } private static function getNullVector() { From 6941bdbb7414a8f7bcc602cef2cdbba788a3e189 Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Wed, 11 Jan 2023 15:57:09 +0100 Subject: [PATCH 09/17] Fix ClusterMergingWithThirdAdditionalCluster test Signed-off-by: Marcel Klehr --- test/ClusterTest.php | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/ClusterTest.php b/test/ClusterTest.php index a097a433..e7cff954 100644 --- a/test/ClusterTest.php +++ b/test/ClusterTest.php @@ -238,6 +238,7 @@ public function testClusterMergingWithThirdAdditionalCluster() { $detection->setClusterId($clusters[0]->getId()); $this->faceDetectionMapper->update($detection); } + $this->faceClusterMapper->delete($clusters[0]); $numOfDetections = self::INITIAL_DETECTIONS_PER_CLUSTER; $clusterValue = 3 * self::INITIAL_DETECTIONS_PER_CLUSTER; @@ -259,11 +260,12 @@ public function testClusterMergingWithThirdAdditionalCluster() { $clusters = $this->faceClusterMapper->findByUserId(self::TEST_USER1); self::assertCount(2, $clusters); - $detections = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 2, $detections); + $detections1 = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); + $detections2 = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); + $counts = [count($detections1), count($detections2)]; - $detections = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER, $detections); + self::assertCount(1, array_filter($counts, fn($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER)); + self::assertCount(1, array_filter($counts, fn($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER * 2)); } /** From e7a24163e5c144aaa012568b6114e404f38d7aba Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Sun, 5 Feb 2023 18:09:25 +0100 Subject: [PATCH 10/17] Next Generation Signed-off-by: Marcel Klehr --- lib/Clustering/DualTreeBall.php | 130 ++++++ lib/Clustering/DualTreeClique.php | 98 +++++ lib/Clustering/HDBSCAN.php | 168 +++++++ lib/Clustering/MRDistance.php | 57 --- lib/Clustering/MrdBallTree.php | 649 ++++++++++++++++++++++++++++ lib/Clustering/MstClusterer.php | 200 +++++---- lib/Clustering/MstSolver.php | 271 ++++++++++++ lib/Clustering/SquaredDistance.php | 64 +++ lib/Db/FaceDetectionMapper.php | 36 +- lib/Service/FaceClusterAnalyzer.php | 84 +--- 10 files changed, 1558 insertions(+), 199 deletions(-) create mode 100644 lib/Clustering/DualTreeBall.php create mode 100644 lib/Clustering/DualTreeClique.php create mode 100644 lib/Clustering/HDBSCAN.php delete mode 100644 lib/Clustering/MRDistance.php create mode 100644 lib/Clustering/MrdBallTree.php create mode 100644 lib/Clustering/MstSolver.php create mode 100644 lib/Clustering/SquaredDistance.php diff --git a/lib/Clustering/DualTreeBall.php b/lib/Clustering/DualTreeBall.php new file mode 100644 index 00000000..cfcbf799 --- /dev/null +++ b/lib/Clustering/DualTreeBall.php @@ -0,0 +1,130 @@ +longestDistanceInNode) { + $this->longestDistanceInNode = $longestDistance; + } + } + + public function getLongestDistance(): float { + return $this->longestDistanceInNode; + } + + public function resetLongestEdge(): void { + $this->longestDistanceInNode = INF; + foreach ($this->children() as $child) { + $child->resetLongestEdge(); + } + } + + public function resetFullyConnectedStatus(): void { + $this->fullyConnected = false; + foreach ($this->children() as $child) { + $child->resetFullyConnectedStatus(); + } + } + + public function getSetId() { + if (!$this->fullyConnected) { + return null; + } + return $this->setId; + } + + public function isFullyConnected(): bool { + return $this->fullyConnected; + } + + public function propagateSetChanges(array &$labelToSetId) { + if ($this->fullyConnected) { + // If we are already fully connected, we just need to check if the + // set id has changed + foreach ($this->children() as $child) { + $this->setId = $child->propagateSetChanges($labelToSetId); + } + + return $this->setId; + } + + // If, and only if, both children are fully connected and in the same set id then + // we, too, are fully connected + $setId = null; + foreach ($this->children() as $child) { + $retVal = $child->propagateSetChanges($labelToSetId); + + if ($retVal === null) { + return null; + } + + if ($setId !== null && $setId !== $retVal) { + return null; + } + + $setId = $retVal; + } + + $this->setId = $setId; + $this->fullyConnected = true; + + return $this->setId; + } + + /** + * Factory method to build a hypersphere by splitting the dataset into left and right clusters. + * + * @param \Rubix\ML\Datasets\Labeled $dataset + * @param \Rubix\ML\Kernels\Distance\Distance $kernel + * @return self + */ + public static function split(Labeled $dataset, Distance $kernel): self { + $center = []; + + foreach ($dataset->features() as $column => $values) { + if ($dataset->featureType($column)->isContinuous()) { + $center[] = Stats::mean($values); + } else { + $center[] = argmax(array_count_values($values)); + } + } + + $distances = []; + + foreach ($dataset->samples() as $sample) { + $distances[] = $kernel->compute($sample, $center); + } + + $radius = max($distances) ?: 0.0; + + $leftCentroid = $dataset->sample(argmax($distances)); + + $distances = []; + + foreach ($dataset->samples() as $sample) { + $distances[] = $kernel->compute($sample, $leftCentroid); + } + + $rightCentroid = $dataset->sample(argmax($distances)); + + $subsets = $dataset->spatialSplit($leftCentroid, $rightCentroid, $kernel); + + return new self($center, $radius, $subsets); + } +} diff --git a/lib/Clustering/DualTreeClique.php b/lib/Clustering/DualTreeClique.php new file mode 100644 index 00000000..910f0311 --- /dev/null +++ b/lib/Clustering/DualTreeClique.php @@ -0,0 +1,98 @@ +longestDistanceInNode = $longestDistance; + } + + public function getLongestDistance(): float { + return $this->longestDistanceInNode; + } + + public function resetLongestEdge(): void { + $this->longestDistanceInNode = INF; + } + + public function resetFullyConnectedStatus(): void { + $this->fullyConnected = false; + } + + public function getSetId() { + if (!$this->fullyConnected) { + return null; + } + + return $this->setId; + } + + public function isFullyConnected(): bool { + return $this->fullyConnected; + } + + public function propagateSetChanges(array &$labelToSetId) { + if ($this->fullyConnected) { + $this->setId = $labelToSetId[$this->dataset->label(0)]; + return $this->setId; + } + + $labels = $this->dataset->labels(); + + $label = + + $setId = $labelToSetId[array_pop($labels)]; + + foreach ($labels as $label) { + if ($setId !== $labelToSetId[$label]) { + return null; + } + } + + $this->fullyConnected = true; + $this->setId = $setId; + + return $this->setId; + } + + /** + * Terminate a branch with a dataset. + * + * @param \Rubix\ML\Datasets\Labeled $dataset + * @param \Rubix\ML\Kernels\Distance\Distance $kernel + * @return self + */ + public static function terminate(Labeled $dataset, Distance $kernel): self { + $center = []; + + foreach ($dataset->features() as $column => $values) { + if ($dataset->featureType($column)->isContinuous()) { + $center[] = Stats::mean($values); + } else { + $center[] = argmax(array_count_values($values)); + } + } + + $distances = []; + + foreach ($dataset->samples() as $sample) { + $distances[] = $kernel->compute($sample, $center); + } + $radius = max($distances) ?: 0.0; + return new self($dataset, $center, $radius); + } +} diff --git a/lib/Clustering/HDBSCAN.php b/lib/Clustering/HDBSCAN.php new file mode 100644 index 00000000..9d515198 --- /dev/null +++ b/lib/Clustering/HDBSCAN.php @@ -0,0 +1,168 @@ +minClusterSize = $minClusterSize; + $this->mstSolver = new MstSolver($dataset, 20, $sampleSize, $kernel, $oldCoreDistances, $useTrueMst); + } + + public function getCoreNeighborDistances(): array { + return $this->mstSolver->getCoreNeighborDistances(); + } + + /** + * Return the estimator type. + * + * @return \Rubix\ML\EstimatorType + */ + public function type(): EstimatorType { + return EstimatorType::clusterer(); + } + + /** + * Return the data types that the estimator is compatible with. + * + * @return list<\Rubix\ML\DataType> + */ + public function compatibility(): array { + return $this->mstSolver->kernel()->compatibility(); + } + + /** + * Return the settings of the hyper-parameters in an associative array. + * + * @return mixed[] + */ + public function params(): array { + return [ + 'min cluster size' => $this->minClusterSize, + 'sample size' => $this->sampleSize, + 'dual tree' => $this->mstSolver, + ]; + } + + /** + * Form clusters and make predictions from the dataset (hard clustering). + * + * @return list//@return list + */ + public function predict(): array { + // Boruvka algorithm for MST generation + $edges = $this->mstSolver->getMst(); + + // Boruvka complete, $edges now contains our mutual reachability distance MST + if ($this->mstSolver->kernel() instanceof SquaredDistance) { + foreach ($edges as &$edge) { + $edge["distance"] = sqrt($edge["distance"]); + } + } + unset($edge); + + // TODO: Min cluster separation/edge length of MstClusterer to the caller of this class + $mstClusterer = new MstClusterer($edges, null, $this->minClusterSize, null, 0.0); + $flatClusters = $mstClusterer->processCluster(); + + return $flatClusters; + } + + /** + * Return the string representation of the object. + * + * @internal + * + * @return string + */ + public function __toString(): string { + return 'HDBSCAN (' . Params::stringify($this->params()) . ')'; + } +} diff --git a/lib/Clustering/MRDistance.php b/lib/Clustering/MRDistance.php deleted file mode 100644 index 9363f76e..00000000 --- a/lib/Clustering/MRDistance.php +++ /dev/null @@ -1,57 +0,0 @@ - $coreDistances - */ - private array $coreDistances; - private int $coreDistSampleSize; - private Labeled $dataset; - private BallTree $distanceTree; - - public function __construct(int $coreDistSampleSize, Labeled $dataset, Distance $kernel) { - $this->coreDistSampleSize = $coreDistSampleSize; - $this->kernel = $kernel; - $this->coreDistances = []; - $this->dataset = $dataset; - - $this->distanceTree = new BallTree($coreDistSampleSize * 3, $kernel); - $this->distanceTree->grow($dataset); - - $this->kernel = $kernel; - } - - /** - * @param int $a - * @param list $aVector - * @param int $b - * @param list $bVector - * @return float - */ - public function distance(int $a, array $aVector, int $b, array $bVector): float { - $distance = $this->kernel->compute($aVector, $bVector); - - return max($distance, $this->getCoreDistance($a), $this->getCoreDistance($b)); - } - - private function getCoreDistance(int $index): float { - if (!isset($this->coreDistances[$index])) { - [$_1, $_2, $distances] = $this->distanceTree->nearest($this->dataset->sample($index), $this->coreDistSampleSize); - $this->coreDistances[$index] = end($distances); - } - - return $this->coreDistances[$index]; - } -} diff --git a/lib/Clustering/MrdBallTree.php b/lib/Clustering/MrdBallTree.php new file mode 100644 index 00000000..1d738431 --- /dev/null +++ b/lib/Clustering/MrdBallTree.php @@ -0,0 +1,649 @@ +maxLeafSize = $maxLeafSize; + $this->sampleSize = $sampleSize; + + $this->kernel = $kernel ?? new SquaredDistance(); + $this->radiusDiamFactor = $this->kernel instanceof SquaredDistance ? 4 : 2; + + $this->nodeDistances = []; + $this->nodeIds = new \SplObjectStorage(); + } + + public function getCoreNeighborDistances(): array { + return $this->coreNeighborDistances; + } + + public function nodeDistance($queryNode, $referenceNode): float { + // Use cache to accelerate repeated queries + if ($this->nodeIds->contains($queryNode)) { + $queryNodeId = $this->nodeIds[$queryNode]; + } else { + $queryNodeId = $this->nodeIds->count(); + $this->nodeIds[$queryNode] = $queryNodeId; + } + + if ($this->nodeIds->contains($referenceNode)) { + $referenceNodeId = $this->nodeIds[$referenceNode]; + } else { + $referenceNodeId = $this->nodeIds->count(); + $this->nodeIds[$referenceNode] = $referenceNodeId; + } + + if ($referenceNodeId === $queryNodeId) { + return (-$this->radiusDiamFactor * $queryNode->radius()); + } + + $smallIndex = min($queryNodeId, $referenceNodeId); + $largeIndex = max($queryNodeId, $referenceNodeId); + + if (isset($this->nodeDistances[$smallIndex][$largeIndex])) { + $nodeDistance = $this->nodeDistances[$smallIndex][$largeIndex]; + } else { + $nodeDistance = $this->kernel->compute($queryNode->center(), $referenceNode->center()); + + if ($this->kernel instanceof SquaredDistance) { + $nodeDistance = sqrt($nodeDistance) - sqrt($queryNode->radius()) - sqrt($referenceNode->radius()); + $nodeDistance = abs($nodeDistance) * $nodeDistance; + } else { + $nodeDistance = $nodeDistance - $queryNode->radius() - $referenceNode->radius(); + } + + $this->nodeDistances[$smallIndex][$largeIndex] = $nodeDistance; + } + + return $nodeDistance; + } + + /** + * Get tree root. + * + * @internal + * + * @return DualTreeBall + */ + public function getRoot(): Ball { + return $this->root; + } + + /** + * Get the dataset the tree was grown on. + * + * @internal + * + * @return Labeled + */ + public function getDataset(): Labeled { + return $this->dataset; + } + + private function updateNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, &$bestDistances): void { + $querySamples = $queryNode->dataset()->samples(); + $queryLabels = $queryNode->dataset()->labels(); + $referenceSamples = $referenceNode->dataset()->samples(); + $referenceLabels = $referenceNode->dataset()->labels(); + + $longestDistance = 0.0; + $shortestDistance = INF; + + foreach ($querySamples as $queryKey => $querySample) { + $queryLabel = $queryLabels[$queryKey]; + + $bestDistance = $bestDistances[$queryLabel]; + + foreach ($referenceSamples as $referenceKey => $referenceSample) { + $referenceLabel = $referenceLabels[$referenceKey]; + + if ($queryLabel === $referenceLabel) { + continue; + } + + // Calculate native distance + $distance = $this->cachedComputeNative($queryLabel, $querySample, $referenceLabel, $referenceSample); + + if ($distance < $bestDistance) { + //Minimize array queries within these loops: + $coreNeighborDistances = & $this->coreNeighborDistances[$queryLabel]; + $coreNeighborDistances[$referenceLabel] = $distance; + + if (count($coreNeighborDistances) >= $k) { + asort($coreNeighborDistances); + $coreNeighborDistances = array_slice($coreNeighborDistances, 0, $k, true); + $bestDistance = min(end($coreNeighborDistances), $maxRange); + } + } + } + + if ($bestDistance > $longestDistance) { + $longestDistance = $bestDistance; + } + + if ($bestDistance < $shortestDistance) { + $shortestDistance = $bestDistance; + } + $bestDistances[$queryLabel] = $bestDistance; + } + + if ($this->kernel instanceof SquaredDistance) { + $longestDistance = min($longestDistance, (2 * sqrt($queryNode->radius()) + sqrt($shortestDistance)) ** 2); + } else { + $longestDistance = min($longestDistance, 2 * $queryNode->radius() + $shortestDistance); + } + $queryNode->setLongestDistance($longestDistance); + } + + private function findNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, &$bestDistances): void { + $nodeDistance = $this->nodeDistance($queryNode, $referenceNode); + + if ($nodeDistance > 0.0) { + // Calculate smallest possible bound (i.e., d(Q) ): + $currentBound = $queryNode->getLongestDistance(); + + // If node distance is greater than the longest possible edge in this node, + // prune this reference node + if ($nodeDistance > $currentBound) { + return; + } + } + + if ($queryNode instanceof DualTreeClique && $referenceNode instanceof DualTreeClique) { + $this->updateNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, $bestDistances); + return; + } + + if ($queryNode instanceof DualTreeClique) { + foreach ($referenceNode->children() as $child) { + $this->findNearestNeighbors($queryNode, $child, $k, $maxRange, $bestDistances); + } + return; + } + + if ($referenceNode instanceof DualTreeClique) { + $longestDistance = 0.0; + + $queryLeft = $queryNode->left(); + $queryRight = $queryNode->right(); + + $this->findNearestNeighbors($queryLeft, $referenceNode, $k, $maxRange, $bestDistances); + $this->findNearestNeighbors($queryRight, $referenceNode, $k, $maxRange, $bestDistances); + } else { + // --> if ($queryNode instanceof DualTreeBall && $referenceNode instanceof DualTreeBall) + $queryLeft = $queryNode->left(); + $queryRight = $queryNode->right(); + $referenceLeft = $referenceNode->left(); + $referenceRight = $referenceNode->right(); + + // TODO: traverse closest neighbor nodes first + $this->findNearestNeighbors($queryLeft, $referenceLeft, $k, $maxRange, $bestDistances); + $this->findNearestNeighbors($queryLeft, $referenceRight, $k, $maxRange, $bestDistances); + $this->findNearestNeighbors($queryRight, $referenceLeft, $k, $maxRange, $bestDistances); + $this->findNearestNeighbors($queryRight, $referenceRight, $k, $maxRange, $bestDistances); + } + + $longestLeft = $queryLeft->getLongestDistance(); + $longestRight = $queryRight->getLongestDistance(); + + // TODO: min($longestLeft, $longestRight) + 2 * ($queryNode->radius()) <--- Can be made tighter by using the shortest distance from child. + if ($this->kernel instanceof SquaredDistance) { + $longestDistance = max($longestLeft, $longestRight); + $longestLeft = (sqrt($longestLeft) + 2 * (sqrt($queryNode->radius()) - sqrt($queryLeft->radius()))) ** 2; + $longestRight = (sqrt($longestRight) + 2 * (sqrt($queryNode->radius()) - sqrt($queryRight->radius()))) ** 2; + $longestDistance = min($longestDistance, min($longestLeft, $longestRight), (sqrt(min($longestLeft, $longestRight)) + 2 * (sqrt($queryNode->radius()))) ** 2); + } else { + $longestDistance = max($longestLeft, $longestRight); + $longestLeft = $longestLeft + 2 * ($queryNode->radius() - $queryLeft->radius()); + $longestRight = $longestRight + 2 * ($queryNode->radius() - $queryRight->radius()); + $longestDistance = min($longestDistance, min($longestLeft, $longestRight), min($longestLeft, $longestRight) + 2 * ($queryNode->radius())); + } + + $queryNode->setLongestDistance($longestDistance); + + return; + } + + public function kNearestAll($k, float $maxRange = INF): void { + $this->coreNeighborDistances = []; + + $allLabels = $this->dataset->labels(); + $bestDistances = []; + foreach ($allLabels as $label) { + $bestDistances[$label] = $maxRange; + } + + $treeRoot = $this->root; + + $treeRoot->resetFullyConnectedStatus(); + $treeRoot->resetLongestEdge(); + + $this->findNearestNeighbors($treeRoot, $treeRoot, $k, $maxRange, $bestDistances); + } + + /** + * Precompute core distances for the current dataset to accelerate + * subsequent MRD queries. Optionally, utilize core distances that've + * been previously determined for (a subset of) the current dataset. + * Returns the updated core distances for future use. + * + * @internal + * + * @param array|null $oldCoreNeighbors + * @return array + */ + + public function precalculateCoreDistances(?array $oldCoreNeighbors = null) { + if (empty($this->dataset)) { + throw new \Exception("Precalculation of core distances requested but dataset is empty. Call ->grow() first!"); + } + + $labels = $this->dataset->labels(); + + if ($oldCoreNeighbors !== null && !empty($oldCoreNeighbors) && count(reset($oldCoreNeighbors)) >= $this->sampleSize) { + // Determine the search radius for core distances based on the largest old + // core distance (points father than that cannot shorten the old core distances) + $largestOldCoreDistance = 0.0; + + // Utilize old (possibly stale) core distance data + foreach ($oldCoreNeighbors as $label => $oldDistances) { + $coreDistance = (array_values($oldDistances))[$this->sampleSize - 1]; + + if ($coreDistance > $largestOldCoreDistance) { + $largestOldCoreDistance = $coreDistance; + } + + $this->coreNeighborDistances[$label] = $oldDistances; + $this->coreDistances[$label] = $coreDistance; + } + + $updatedOldCoreLabels = []; + + // Don't recalculate core distances for the old labels + $labels = array_filter($labels, function ($label) use ($oldCoreNeighbors) { + return !isset($oldCoreNeighbors[$label]); + }); + + foreach ($labels as $label) { + [$neighborLabels, $neighborDistances] = $this->cachedRange($label, $largestOldCoreDistance); + // TODO: cachedRange may not return $this->sampleSize number of labels. + $this->coreNeighborDistances[$label] = array_combine(array_slice($neighborLabels, 0, $this->sampleSize), array_slice($neighborDistances, 0, $this->sampleSize)); + $this->coreDistances[$label] = $neighborDistances[$this->sampleSize - 1]; + + // If one of the old vertices is within the update radius of this new vertex, + // check whether the old core distance needs to be updated. + foreach ($neighborLabels as $distanceKey => $neighborLabel) { + if (isset($oldCoreNeighbors[$neighborLabel])) { + $newDistance = $neighborDistances[$distanceKey]; + if ($newDistance < $this->coreDistances[$neighborLabel]) { + $this->coreNeighborDistances[$neighborLabel][$label] = $newDistance; + $updatedOldCoreLabels[$neighborLabel] = true; + } + } + } + } + + foreach (array_keys($updatedOldCoreLabels) as $label) { + asort($this->coreNeighborDistances[$label]); + $this->coreNeighborDistances[$label] = array_slice($this->coreNeighborDistances[$label], 0, $this->sampleSize, true); + $this->coreDistances[$label] = end($this->coreNeighborDistances[$label]); + } + } else { // $oldCoreNeighbors === null + $this->kNearestAll($this->sampleSize, INF); + + foreach ($this->dataset->labels() as $label) { + $this->coreDistances[$label] = end($this->coreNeighborDistances[$label]); + } + } + + return $this->coreNeighborDistances; + } + + /** + * Inserts a new neighbor to core neighbors if the distance + * is greater than the current largest distance for the query label. + * + * Returns the updated core distance or INF if there are less than $this->sampleSize neighbors. + * + * @internal + * + * @param mixed $queryLabel + * @param mixed $referenceLabel + * @param float $distance + * @return float + */ + private function insertToCoreDistances($queryLabel, $referenceLabel, $distance): float { + // Update the core distances of the queryLabel + if (isset($this->coreDistances[$queryLabel])) { + if ($this->coreDistances[$queryLabel] > $distance) { + $this->coreNeighborDistances[$queryLabel][$referenceLabel] = $distance; + asort($this->coreNeighborDistances[$queryLabel]); + + $this->coreNeighborDistances[$queryLabel] = array_slice($this->coreNeighborDistances[$queryLabel], 0, $this->sampleSize, true); + $this->coreDistances[$queryLabel] = end($this->coreNeighborDistances[$queryLabel]); + } + } else { + $this->coreNeighborDistances[$queryLabel][$referenceLabel] = $distance; + + if (count($this->coreNeighborDistances[$queryLabel]) >= $this->sampleSize) { + asort($this->coreNeighborDistances[$queryLabel]); + + $this->coreNeighborDistances[$queryLabel] = array_slice($this->coreNeighborDistances[$queryLabel], 0, $this->sampleSize, true); + $this->coreDistances[$queryLabel] = end($this->coreNeighborDistances[$queryLabel]); + } + } + + // Update the core distances of the referenceLabel (this is not necessary, but *may* accelerate the algo slightly) + if (isset($this->coreDistances[$referenceLabel])) { + if ($this->coreDistances[$referenceLabel] > $distance) { + $this->coreNeighborDistances[$referenceLabel][$queryLabel] = $distance; + asort($this->coreNeighborDistances[$referenceLabel]); + + $this->coreNeighborDistances[$referenceLabel] = array_slice($this->coreNeighborDistances[$referenceLabel], 0, $this->sampleSize, true); + $this->coreDistances[$referenceLabel] = end($this->coreNeighborDistances[$referenceLabel]); + } + } else { + $this->coreNeighborDistances[$referenceLabel][$queryLabel] = $distance; + + if (count($this->coreNeighborDistances[$referenceLabel]) > $this->sampleSize) { + asort($this->coreNeighborDistances[$referenceLabel]); + $this->coreNeighborDistances[$referenceLabel] = array_slice($this->coreNeighborDistances[$referenceLabel], 0, $this->sampleSize, true); + $this->coreDistances[$referenceLabel] = end($this->coreNeighborDistances[$referenceLabel]); + } + } + + return $this->coreDistances[$queryLabel] ?? INF; + } + + /** + * Compute the mutual reachability distance between two vectors. + * + * @internal + * + * @param int|string $a + * @param int|string $b + * @return float + */ + public function computeMrd($a, array $a_vector, $b, array $b_vector): float { + $distance = $this->cachedComputeNative($a, $a_vector, $b, $b_vector); + + return max($distance, $this->getCoreDistance($a), $this->getCoreDistance($b)); + } + + public function getCoreDistance($label): float { + if (!isset($this->coreDistances[$label])) { + [$labels, $distances] = $this->getCoreNeighbors($label); + $this->coreDistances[$label] = end($distances); + } + + return $this->coreDistances[$label]; + } + + public function cachedComputeNative($a, array $a_vector, $b, array $b_vector, bool $storeNewCalculations = true): float { + if (isset($this->coreNeighborDistances[$a][$b])) { + return $this->coreNeighborDistances[$a][$b]; + } + if (isset($this->coreNeighborDistances[$b][$a])) { + return $this->coreNeighborDistances[$b][$a]; + } + + if ($a < $b) { + $smallIndex = $a; + $largeIndex = $b; + } else { + $smallIndex = $b; + $largeIndex = $a; + } + + if (!isset($this->nativeInterpointCache[$smallIndex][$largeIndex])) { + $distance = $this->kernel->compute($a_vector, $b_vector); + if ($storeNewCalculations) { + $this->nativeInterpointCache[$smallIndex][$largeIndex] = $distance; + } + return $distance; + } + + return $this->nativeInterpointCache[$smallIndex][$largeIndex]; + } + + /** + * Run a n nearest neighbors search on a single label and return the neighbor labels, and distances in a 2-tuple + * + * + * @internal + * + * @param int|string $sampleLabel + * @param bool $useCachedValues + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + * @return array{list,list} + */ + public function getCoreNeighbors($sampleLabel, bool $useCachedValues = true): array { + if ($useCachedValues && isset($this->coreNeighborDistances[$sampleLabel])) { + return [array_keys($this->coreNeighborDistances[$sampleLabel]), array_values($this->coreNeighborDistances[$sampleLabel])]; + } + + $sampleKey = array_search($sampleLabel, $this->dataset->labels()); + $sample = $this->dataset->sample($sampleKey); + + $squaredDistance = $this->kernel instanceof SquaredDistance; + + /** @var list **/ + $stack = [$this->root]; + $stackDistances = [0.0]; + $radius = INF; + + $labels = $distances = []; + + while ($current = array_pop($stack)) { + $currentDistance = array_pop($stackDistances); + + if ($currentDistance > $radius) { + continue; + } + + if ($current instanceof DualTreeBall) { + foreach ($current->children() as $child) { + if ($child instanceof Hypersphere) { + $distance = $this->kernel->compute($sample, $child->center()); + + if ($squaredDistance) { + $distance = sqrt($distance); + $childRadius = sqrt($child->radius()); + $distance = $distance - $childRadius; + $distance = abs($distance) * $distance; + } else { + $distance = $distance - $child->radius(); + } + + if ($distance < $radius) { + $stack[] = $child; + $stackDistances[] = $distance; + } + } + } + array_multisort($stackDistances, SORT_DESC, $stack); + } elseif ($current instanceof DualTreeClique) { + $dataset = $current->dataset(); + $neighborLabels = $dataset->labels(); + + foreach ($dataset->samples() as $i => $neighbor) { + if ($neighborLabels[$i] === $sampleLabel) { + continue; + } + + $distance = $this->cachedComputeNative($sampleLabel, $sample, $neighborLabels[$i], $neighbor); + + if ($distance <= $radius) { + $labels[] = $neighborLabels[$i]; + $distances[] = $distance; + } + } + + if (count($labels) >= $this->sampleSize) { + array_multisort($distances, $labels); + $radius = $distances[$this->sampleSize - 1]; + $labels = array_slice($labels, 0, $this->sampleSize); + $distances = array_slice($distances, 0, $this->sampleSize); + } + } + } + return [$labels, $distances]; + } + + /** + * Return all labels, and distances within a given radius of a sample. + * + * + * @internal + * + * @param int $sampleLabel + * @param float $radius + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + * @throws \Rubix\ML\Exceptions\RuntimeException + * @return array{list,list} + */ + public function cachedRange($sampleLabel, float $radius): array { + $sampleKey = array_search($sampleLabel, $this->dataset->labels()); + $sample = $this->dataset->sample($sampleKey); + + $squaredDistance = $this->kernel instanceof SquaredDistance; + + /** @var list **/ + $stack = [$this->root]; + + $labels = $distances = []; + + while ($current = array_pop($stack)) { + if ($current instanceof DualTreeBall) { + foreach ($current->children() as $child) { + if ($child instanceof Hypersphere) { + $distance = $this->kernel->compute($sample, $child->center()); + + if ($squaredDistance) { + $distance = sqrt($distance); + $childRadius = sqrt($child->radius()); + $minDistance = $distance - $childRadius; + $minDistance = abs($minDistance) * $minDistance; + $maxDistance = ($distance + $childRadius) ** 2; + } else { + $childRadius = $child->radius(); + $minDistance = $distance - $childRadius; + $maxDistance = $distance + $childRadius; + } + + if ($minDistance < $radius) { + if ($maxDistance < $radius && $child instanceof DualTreeBall) { + // The whole child is within the specified radius: greedily add all sub-children recursively to the stack + $subStack = [$child]; + while ($subStackCurrent = array_pop($subStack)) { + foreach ($subStackCurrent->children() as $subChild) { + if ($subChild instanceof DualTreeClique) { + $stack[] = $subChild; + } else { + $subStack[] = $subChild; + } + } + } + } else { + $stack[] = $child; + } + } + } + } + } elseif ($current instanceof DualTreeClique) { + $dataset = $current->dataset(); + $neighborLabels = $dataset->labels(); + + foreach ($dataset->samples() as $i => $neighbor) { + $distance = $this->cachedComputeNative($sampleLabel, $sample, $neighborLabels[$i], $neighbor); + + if ($distance <= $radius) { + $labels[] = $neighborLabels[$i]; + $distances[] = $distance; + } + } + } + } + array_multisort($distances, $labels); + return [$labels, $distances]; + } + + /** + * Insert a root node and recursively split the dataset until a terminating + * condition is met. This also sets the dataset that will be used to calculate + * core distances. Previously calculated core distances will be stored/used + * despite calling grow, unless precalculateCoreDistances() is called again. + * + * @internal + * + * @param \Rubix\ML\Datasets\Labeled $dataset + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + */ + public function grow(Labeled $dataset): void { + $this->dataset = $dataset; + $this->root = DualTreeBall::split($dataset, $this->kernel); + + $stack = [$this->root]; + + while ($current = array_pop($stack)) { + [$left, $right] = $current->subsets(); + + $current->cleanup(); + + if ($left->numSamples() > $this->maxLeafSize) { + $node = DualTreeBall::split($left, $this->kernel); + + $current->attachLeft($node); + + $stack[] = $node; + } elseif (!$left->empty()) { + $current->attachLeft(DualTreeClique::terminate($left, $this->kernel)); + } + + if ($right->numSamples() > $this->maxLeafSize) { + $node = DualTreeBall::split($right, $this->kernel); + + if ($node->isPoint()) { + $current->attachRight(DualTreeClique::terminate($right, $this->kernel)); + } else { + $current->attachRight($node); + + $stack[] = $node; + } + } elseif (!$right->empty()) { + $current->attachRight(DualTreeClique::terminate($right, $this->kernel)); + } + } + } +} diff --git a/lib/Clustering/MstClusterer.php b/lib/Clustering/MstClusterer.php index 970784a1..849f5d3e 100644 --- a/lib/Clustering/MstClusterer.php +++ b/lib/Clustering/MstClusterer.php @@ -9,38 +9,25 @@ // TODO: core edges are not always stored properly (if two halves of the remaining clusters are both pruned at the same time) // TODO: store vertex lambda length (relative to cluster lambda length) for all vertices for soft clustering. class MstClusterer { - /** - * @var array - */ private array $edges; - /** - * @var array - */ private array $remainingEdges; private float $startingLambda; private float $clusterWeight; private int $minimumClusterSize; private array $coreEdges; private bool $isRoot; - private float $maxEdgeLength; + private array $mapVerticesToEdges; private float $minClusterSeparation; - /** - * @param array $edges - * @param int $minimumClusterSize - * @param float|null $startingLambda - * @param float $maxEdgeLength - * @param float $minClusterSeparation - */ - public function __construct(array $edges, int $minimumClusterSize, ?float $startingLambda = null, float $maxEdgeLength = 0.5, float $minClusterSeparation = 0.1) { + public function __construct(array $edges, ?array $mapVerticesToEdges, int $minimumClusterSize, ?float $startingLambda = null, float $minClusterSeparation = 0.1) { //Ascending sort of edges while perserving original keys. $this->edges = $edges; - uasort($this->edges, static function ($a, $b) { - if ($a[1] > $b[1]) { + uasort($this->edges, function ($a, $b) { + if ($a["distance"] > $b["distance"]) { return 1; } - if ($a[1] < $b[1]) { + if ($a["distance"] < $b["distance"]) { return -1; } return 0; @@ -48,6 +35,16 @@ public function __construct(array $edges, int $minimumClusterSize, ?float $start $this->remainingEdges = $this->edges; + if ($mapVerticesToEdges === null) { + $mapVerticesToEdges = []; + foreach ($this->edges as $edgeIndex => $edge) { + $mapVerticesToEdges[$edge['vertexFrom']][$edgeIndex] = true; + $mapVerticesToEdges[$edge['vertexTo']][$edgeIndex] = true; + } + } + + $this->mapVerticesToEdges = $mapVerticesToEdges; + if (is_null($startingLambda)) { $this->isRoot = true; $this->startingLambda = 0.0; @@ -62,7 +59,7 @@ public function __construct(array $edges, int $minimumClusterSize, ?float $start $this->clusterWeight = 0.0; - $this->maxEdgeLength = $maxEdgeLength; + $this->minClusterSeparation = $minClusterSeparation; } @@ -74,39 +71,56 @@ public function processCluster(): array { $edgeCount = count($this->remainingEdges); if ($edgeCount < ($this->minimumClusterSize - 1)) { - if ($edgeLength > $this->maxEdgeLength) { - return []; + foreach ($this->coreEdges as &$edge) { + $edge['finalLambda'] = $currentLambda; } + unset($edge); - $this->coreEdges = $this->remainingEdges; + foreach (array_keys($this->remainingEdges) as $edgeKey) { + $this->edges[$edgeKey]['finalLambda'] = $currentLambda; + } return [$this]; } - $vertexConnectedTo = array_key_last($this->remainingEdges); + if ($edgeCount < (2 * $this->minimumClusterSize - 1)) { + // The end is near; this cluster cannot be split into two anymore + $this->coreEdges = $this->remainingEdges; + } + + $currentLongestEdgeKey = array_key_last($this->remainingEdges); $currentLongestEdge = array_pop($this->remainingEdges); - $vertexConnectedFrom = $currentLongestEdge[0]; - $edgeLength = $currentLongestEdge[1]; + $vertexConnectedFrom = $currentLongestEdge["vertexFrom"]; + $vertexConnectedTo = $currentLongestEdge["vertexTo"]; + $edgeLength = $currentLongestEdge["distance"]; - if ($edgeLength > $this->maxEdgeLength) { - // Prevent formation of clusters with edges longer than the maximum edge length - $currentLambda = $lastLambda = 1 / $edgeLength; - } elseif ($edgeLength > 0.0) { + unset($this->mapVerticesToEdges[$vertexConnectedFrom][$currentLongestEdgeKey]); + unset($this->mapVerticesToEdges[$vertexConnectedTo][$currentLongestEdgeKey]); + + if ($edgeLength > 0.0) { $currentLambda = 1 / $edgeLength; } $this->clusterWeight += ($currentLambda - $lastLambda) * $edgeCount; $lastLambda = $currentLambda; - if (!$this->pruneFromCluster($vertexConnectedTo) && !$this->pruneFromCluster($vertexConnectedFrom)) { + $this->edges[$currentLongestEdgeKey]["finalLambda"] = $currentLambda; + + if (!$this->pruneFromCluster($vertexConnectedTo, $currentLambda) && !$this->pruneFromCluster($vertexConnectedFrom, $currentLambda)) { // This cluster will (probably) split into two child clusters: - $childClusterEdges1 = $this->getChildClusterEdges($vertexConnectedTo); - $childClusterEdges2 = $this->getChildClusterEdges($vertexConnectedFrom); + [$childClusterEdges1, $childClusterVerticesToEdges1] = $this->getChildClusterComponents($vertexConnectedTo); + [$childClusterEdges2, $childClusterVerticesToEdges2] = $this->getChildClusterComponents($vertexConnectedFrom); if ($edgeLength < $this->minClusterSeparation) { - $this->remainingEdges = count($childClusterEdges1) > count($childClusterEdges2) ? $childClusterEdges1 : $childClusterEdges2; + if (count($childClusterEdges1) > count($childClusterEdges2)) { + $this->remainingEdges = $childClusterEdges1; + $this->mapVerticesToEdges = $childClusterVerticesToEdges1; + } else { + $this->remainingEdges = $childClusterEdges2; + $this->mapVerticesToEdges = $childClusterVerticesToEdges2; + } continue; } @@ -114,10 +128,10 @@ public function processCluster(): array { // Return a list of children if the weight of all children is more than $this->clusterWeight. // Otherwise return the current cluster and discard the children. This way we "choose" a combination // of clusters that weigh the most (i.e. have most (excess of) mass). Always discard the root cluster. - $finalLambda = $currentLambda; - $childCluster1 = new MstClusterer($childClusterEdges1, $this->minimumClusterSize, $finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); - $childCluster2 = new MstClusterer($childClusterEdges2, $this->minimumClusterSize, $finalLambda, $this->maxEdgeLength, $this->minClusterSeparation); + + $childCluster1 = new MstClusterer($childClusterEdges1, $childClusterVerticesToEdges1, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation); + $childCluster2 = new MstClusterer($childClusterEdges2, $childClusterVerticesToEdges2, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation); // Resolve all chosen child clusters recursively $childClusters = array_merge($childCluster1->processCluster(), $childCluster2->processCluster()); @@ -130,102 +144,126 @@ public function processCluster(): array { if (($childrenWeight > $this->clusterWeight) || $this->isRoot) { return $childClusters; + } else { + foreach (array_keys($this->remainingEdges) as $edgeKey) { + $this->edges[$edgeKey]['finalLambda'] = $currentLambda; + } } return [$this]; } - - if ($edgeLength > $this->maxEdgeLength) { - $this->edges = $this->remainingEdges; - } } } - private function pruneFromCluster(int $vertexId): bool { + private function pruneFromCluster(int $vertexId, float $currentLambda): bool { $edgeIndicesToPrune = []; + $verticesToPrune = []; $vertexStack = [$vertexId]; while (!empty($vertexStack)) { $currentVertex = array_pop($vertexStack); + $verticesToPrune[] = $currentVertex; - if (count($edgeIndicesToPrune) >= ($this->minimumClusterSize - 1)) { + if (count($verticesToPrune) >= $this->minimumClusterSize) { return false; } - // Traverse the MST edges backward - if (isset($this->remainingEdges[$currentVertex]) && !in_array($currentVertex, $edgeIndicesToPrune)) { - $incomingEdge = $this->remainingEdges[$currentVertex]; - $edgeIndicesToPrune[] = $currentVertex; - - $vertexStack[] = $incomingEdge[0]; - } + foreach (array_keys($this->mapVerticesToEdges[$currentVertex]) as $edgeKey) { + if (isset($edgeIndicesToPrune[$edgeKey])) { + continue; + } - // Traverse the MST edges forward - foreach ($this->remainingEdges as $key => $edge) { - if (($edge[0] == $currentVertex) && !in_array($key, $edgeIndicesToPrune)) { - $vertexStack[] = $key; - $edgeIndicesToPrune[] = $key; + if ($this->remainingEdges[$edgeKey]["vertexFrom"] === $currentVertex) { + $vertexStack[] = $this->remainingEdges[$edgeKey]["vertexTo"]; + $edgeIndicesToPrune[$edgeKey] = true; + } elseif ($this->remainingEdges[$edgeKey]["vertexTo"] === $currentVertex) { + $vertexStack[] = $this->remainingEdges[$edgeKey]["vertexFrom"]; + $edgeIndicesToPrune[$edgeKey] = true; } } } // Prune edges - foreach ($edgeIndicesToPrune as $edgeToPrune) { + foreach (array_keys($edgeIndicesToPrune) as $edgeToPrune) { + $this->edges[$edgeToPrune]['finalLambda'] = $currentLambda; unset($this->remainingEdges[$edgeToPrune]); } + // Prune vertices to edges map (not stricly necessary but saves some memory) + foreach ($verticesToPrune as $vertexLabel) { + unset($this->mapVerticesToEdges[$vertexLabel]); + } + return true; } - private function getChildClusterEdges(int $vertexId): array { + private function getChildClusterComponents(int $vertexId): array { $vertexStack = [$vertexId]; - $edgesInCluster = []; + $edgeIndicesInCluster = []; + $verticesInCluster = []; while (!empty($vertexStack)) { $currentVertex = array_pop($vertexStack); + $verticesInCluster[$currentVertex] = $this->mapVerticesToEdges[$currentVertex]; - // Traverse the MST edges backward - if (isset($this->remainingEdges[$currentVertex]) && !isset($edgesInCluster[$currentVertex])) { - $incomingEdge = $this->remainingEdges[$currentVertex]; - - //Edges are indexed by the vertex they're connected to - $edgesInCluster[$currentVertex] = $incomingEdge; + foreach (array_keys($this->mapVerticesToEdges[$currentVertex]) as $edgeKey) { + if (isset($edgeIndicesInCluster[$edgeKey])) { + continue; + } - $vertexStack[] = $incomingEdge[0]; + if ($this->remainingEdges[$edgeKey]["vertexFrom"] === $currentVertex) { + $vertexStack[] = $this->remainingEdges[$edgeKey]["vertexTo"]; + $edgeIndicesInCluster[$edgeKey] = true; + } elseif ($this->remainingEdges[$edgeKey]["vertexTo"] === $currentVertex) { + $vertexStack[] = $this->remainingEdges[$edgeKey]["vertexFrom"]; + $edgeIndicesInCluster[$edgeKey] = true; + } } + } - // Traverse the MST edges forward - foreach ($this->remainingEdges as $key => $edge) { - if ($edge[0] == $currentVertex && !isset($edgesInCluster[$key])) { - $vertexStack[] = $key; - $edgesInCluster[$key] = $edge; - } + // Collecting the edges is done in a separate loop to perserve the ordering according to length. + // (See constructor.) + $edgesInCluster = []; + foreach ($this->remainingEdges as $edgeKey => $edge) { + if (isset($edgeIndicesInCluster[$edgeKey])) { + $edgesInCluster[$edgeKey] = $edge; } } - return $edgesInCluster; + return [$edgesInCluster, $verticesInCluster]; } public function getClusterWeight(): float { return $this->clusterWeight; } + public function getClusterVertices(): array { + $vertices = []; - /** - * @returns list - */ - public function getVertexKeys(): array { - $vertexKeys = []; - - foreach ($this->edges as $key => $edge) { - $vertexKeys[] = $key; - $vertexKeys[] = $edge[0]; + foreach ($this->edges as $edge) { + $vertices[$edge["vertexTo"]] = min($edge["finalLambda"], $vertices[$edge["vertexTo"]] ?? INF); + $vertices[$edge["vertexFrom"]] = min($edge["finalLambda"], $vertices[$edge["vertexFrom"]] ?? INF); } - return array_unique($vertexKeys); + return $vertices; } public function getCoreEdges(): array { return $this->coreEdges; } + + public function getClusterEdges(): array { + return $this->edges; + } + + public function getCoreVertices(): array { + $vertices = []; + + foreach ($this->coreEdges as $edge) { + $vertices[$edge["vertexTo"]] = min($edge["finalLambda"], $vertices[$edge["vertexTo"]] ?? INF); + $vertices[$edge["vertexFrom"]] = min($edge["finalLambda"], $vertices[$edge["vertexFrom"]] ?? INF); + } + + return $vertices; + } } diff --git a/lib/Clustering/MstSolver.php b/lib/Clustering/MstSolver.php new file mode 100644 index 00000000..37dc1ceb --- /dev/null +++ b/lib/Clustering/MstSolver.php @@ -0,0 +1,271 @@ +kernel = $kernel ?? new SquaredDistance(); + + $this->tree = new MrdBallTree($maxLeafSize, $sampleSize, $this->kernel); + + $this->tree->grow($fullDataset); + $this->tree->precalculateCoreDistances($oldCoreDistances); + + $this->useTrueMst = $useTrueMst; + } + + public function kernel(): Distance { + return $this->kernel; + } + + public function getCoreNeighborDistances(): array { + return $this->tree->getCoreNeighborDistances(); + } + + public function getTree(): MrdBallTree { + return $this->tree; + } + + private function updateEdges($queryNode, $referenceNode, array &$newEdges, array &$vertexToSetId): void { + $querySamples = $queryNode->dataset()->samples(); + $queryLabels = $queryNode->dataset()->labels(); + $referenceSamples = $referenceNode->dataset()->samples(); + $referenceLabels = $referenceNode->dataset()->labels(); + + $longestDistance = 0.0; + $shortestDistance = INF; + + foreach ($querySamples as $queryKey => $querySample) { + $queryLabel = $queryLabels[$queryKey]; + $querySetId = $vertexToSetId[$queryLabel]; + + if ($this->tree->getCoreDistance($queryLabel) > ($newEdges[$querySetId]["distance"] ?? INF)) { + // The core distance of the current vertex is greater than the current best edge + // for this setId. This means that the MRD will always be greater than the current best. + continue; + } + + foreach ($referenceSamples as $referenceKey => $referenceSample) { + $referenceLabel = $referenceLabels[$referenceKey]; + $referenceSetId = $vertexToSetId[$referenceLabel]; + + if ($querySetId === $referenceSetId) { + continue; + } + + $distance = $this->tree->computeMrd($queryLabel, $querySample, $referenceLabel, $referenceSample); + + if ($distance < ($newEdges[$querySetId]["distance"] ?? INF)) { + $newEdges[$querySetId] = ["vertexFrom" => $queryLabel, "vertexTo" => $referenceLabel, "distance" => $distance]; + } + } + $candidateDist = $newEdges[$querySetId]["distance"] ?? INF; + if ($candidateDist > $longestDistance) { + $longestDistance = $candidateDist; + } + + if ($candidateDist < $shortestDistance) { + $shortestDistance = $candidateDist; + } + } + + // Update the bound of the query node + if ($this->kernel instanceof SquaredDistance) { + $longestDistance = min($longestDistance, (2 * sqrt($queryNode->radius()) + sqrt($shortestDistance)) ** 2); + } else { + $longestDistance = min($longestDistance, 2 * $queryNode->radius() + $shortestDistance); + } + + $queryNode->setLongestDistance($longestDistance); + } + + private function findSetNeighbors($queryNode, $referenceNode, array &$newEdges, array &$vertexToSetId): void { + if ($queryNode->isFullyConnected() && $referenceNode->isFullyConnected()) { + if ($queryNode->getSetId() === $referenceNode->getSetId()) { + // These nodes are connected and in the same set, so we can prune this reference node. + return; + } + } + + // if d(Q,R) > d(Q) then + // return; + + $nodeDistance = $this->tree->nodeDistance($queryNode, $referenceNode); + + if ($nodeDistance > 0.0) { + // Calculate smallest possible bound (i.e., d(Q) ): + if ($queryNode->isFullyConnected()) { + $currentBound = min($newEdges[$queryNode->getSetId()]["distance"] ?? INF, $queryNode->getLongestDistance()); + } else { + $currentBound = $queryNode->getLongestDistance(); + } + // If node distance is greater than the longest possible edge in this node, + // prune this reference node + if ($nodeDistance > $currentBound) { + return; + } + } + + if ($queryNode instanceof DualTreeClique && $referenceNode instanceof DualTreeClique) { + $this->updateEdges($queryNode, $referenceNode, $newEdges, $vertexToSetId); + return; + } + + if ($queryNode instanceof DualTreeClique) { + foreach ($referenceNode->children() as $child) { + $this->findSetNeighbors($queryNode, $child, $newEdges, $vertexToSetId); + } + return; + } + + if ($referenceNode instanceof DualTreeClique) { + $longestDistance = 0.0; + + $queryLeft = $queryNode->left(); + $queryRight = $queryNode->right(); + + $this->findSetNeighbors($queryLeft, $referenceNode, $newEdges, $vertexToSetId); + $this->findSetNeighbors($queryRight, $referenceNode, $newEdges, $vertexToSetId); + } else { // if ($queryNode instanceof DualTreeBall && $referenceNode instanceof DualTreeBall) + $queryLeft = $queryNode->left(); + $queryRight = $queryNode->right(); + $referenceLeft = $referenceNode->left(); + $referenceRight = $referenceNode->right(); + + $this->findSetNeighbors($queryLeft, $referenceLeft, $newEdges, $vertexToSetId); + $this->findSetNeighbors($queryRight, $referenceRight, $newEdges, $vertexToSetId); + $this->findSetNeighbors($queryLeft, $referenceRight, $newEdges, $vertexToSetId); + $this->findSetNeighbors($queryRight, $referenceLeft, $newEdges, $vertexToSetId); + } + + $longestLeft = $queryLeft->getLongestDistance(); + $longestRight = $queryRight->getLongestDistance(); + + // TODO: min($longestLeft, $longestRight) + 2 * ($queryNode->radius()) <--- Can be made tighter? + if ($this->kernel instanceof SquaredDistance) { + $longestDistance = max($longestLeft, $longestRight); + $longestLeft = (sqrt($longestLeft) + 2 * (sqrt($queryNode->radius()) - sqrt($queryLeft->radius()))) ** 2; + $longestRight = (sqrt($longestRight) + 2 * (sqrt($queryNode->radius()) - sqrt($queryRight->radius()))) ** 2; + $longestDistance = min($longestDistance, min($longestLeft, $longestRight), (sqrt(min($longestLeft, $longestRight)) + 2 * (sqrt($queryNode->radius()))) ** 2); + } else { + $longestDistance = max($longestLeft, $longestRight); + $longestLeft = $longestLeft + 2 * ($queryNode->radius() - $queryLeft->radius()); + $longestRight = $longestRight + 2 * ($queryNode->radius() - $queryRight->radius()); + $longestDistance = min($longestDistance, min($longestLeft, $longestRight), min($longestLeft, $longestRight) + 2 * ($queryNode->radius())); + } + + $queryNode->setLongestDistance($longestDistance); + + return; + } + + public function getMst(): array { + $edges = []; + + // MST generation using dual-tree boruvka algorithm + + $treeRoot = $this->tree->getRoot(); + + $treeRoot->resetFullyConnectedStatus(); + + $allLabels = $this->tree->getDataset()->labels(); + + $vertexToSetId = array_combine($allLabels, range(0, count($allLabels) - 1)); + + $vertexSets = []; + foreach ($vertexToSetId as $vertex => $setId) { + $vertexSets[$setId] = [$vertex]; + } + + if (!$this->useTrueMst) { + $treeRoot->resetLongestEdge(); + } + + // Use nearest neighbors known from determining core distances for each vertex to + // get the first set of $newEdges (we essentially can skip the first round of Boruvka): + $newEdges = []; + + foreach ($allLabels as $label) { + [$coreNeighborLabels, $coreNeighborDistances] = $this->tree->getCoreNeighbors($label); + + $coreDistance = end($coreNeighborDistances); + + foreach ($coreNeighborLabels as $neighborLabel) { + if ($neighborLabel === $label) { + continue; + } + + if ($this->tree->getCoreDistance($neighborLabel) <= $coreDistance) { + // This point is our nearest neighbor in mutual reachability terms, so + // an edge spanning these vertices will belong to the MST. + $newEdges[] = ["vertexFrom" => $label, "vertexTo" => $neighborLabel, "distance" => $coreDistance]; + break; + } + } + } + + /////////////////////////////////////////////////////////////////////////////// + // Main dual tree Boruvka loop: + + while (true) { + //Add new edges + //Update vertex to set/set to vertex mappings + foreach ($newEdges as $connectingEdge) { + $setId1 = $vertexToSetId[$connectingEdge["vertexFrom"]]; + $setId2 = $vertexToSetId[$connectingEdge["vertexTo"]]; + + if ($setId1 === $setId2) { + // These sets have already been merged earlier in this loop + continue; + } + + $edges[] = $connectingEdge; + + if (count($vertexSets[$setId1]) < count($vertexSets[$setId2])) { + // Make a switch such that the larger set is always Id1 + [$setId1, $setId2] = [$setId2, $setId1]; + } + + // Assign all vertices in set 2 to set 1 + foreach ($vertexSets[$setId2] as $vertexLabel) { + $vertexToSetId[$vertexLabel] = $setId1; + } + + $vertexSets[$setId1] = array_merge($vertexSets[$setId1], $vertexSets[$setId2]); + unset($vertexSets[$setId2]); + } + + // Check for exit condition + if (count($vertexSets) === 1) { + break; + } + + //Update the tree + if ($this->useTrueMst || empty($newEdges)) { + $treeRoot->resetLongestEdge(); + } + + if (!empty($newEdges)) { + $treeRoot->propagateSetChanges($vertexToSetId); + } + + // Clear the array for a set of new edges + $newEdges = []; + + $this->findSetNeighbors($treeRoot, $treeRoot, $newEdges, $vertexToSetId); + } + + return $edges; + } +} diff --git a/lib/Clustering/SquaredDistance.php b/lib/Clustering/SquaredDistance.php new file mode 100644 index 00000000..6b4fda94 --- /dev/null +++ b/lib/Clustering/SquaredDistance.php @@ -0,0 +1,64 @@ + + */ + public function compatibility(): array { + return [ + DataType::continuous(), + ]; + } + + /** + * Compute the distance between two vectors. + * + * @internal + * + * @param list $a + * @param list $b + * @return float + */ + public function compute(array $a, array $b): float { + $distance = 0.0; + + foreach ($a as $i => $value) { + $distance += ($value - $b[$i]) ** 2; + } + + return $distance; + } + + /** + * Return the string representation of the object. + * + * @internal + * + * @return string + */ + public function __toString(): string { + return 'Squared distance'; + } +} diff --git a/lib/Db/FaceDetectionMapper.php b/lib/Db/FaceDetectionMapper.php index 4b884061..960d230c 100644 --- a/lib/Db/FaceDetectionMapper.php +++ b/lib/Db/FaceDetectionMapper.php @@ -9,12 +9,16 @@ use OCP\AppFramework\Db\Entity; use OCP\AppFramework\Db\QBMapper; use OCP\DB\QueryBuilder\IQueryBuilder; +use OCP\IConfig; use OCP\IDBConnection; class FaceDetectionMapper extends QBMapper { - public function __construct(IDBConnection $db) { + private IConfig $config; + + public function __construct(IDBConnection $db, IConfig $config) { parent::__construct($db, 'recognize_face_detections', FaceDetection::class); $this->db = $db; + $this->config = $config; } /** @@ -119,6 +123,36 @@ public function findByFileIdAndClusterId(int $fileId, int $clusterId) : FaceDete return $this->findEntity($qb); } + /** + * @param string $userId + * @return \OCA\Recognize\Db\FaceDetection[] + * @throws \OCP\DB\Exception + */ + public function findUnclusteredByUserId(string $userId) : array { + $qb = $this->db->getQueryBuilder(); + $qb->select(FaceDetection::$columns) + ->from('recognize_face_detections') + ->where($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))) + ->andWhere($qb->expr()->isNull('cluster_id')); + return $this->findEntities($qb); + } + + public function findClusterSample(int $clusterId, int $n) { + $qb = $this->db->getQueryBuilder(); + $qb->select(FaceDetection::$columns) + ->from('recognize_face_detections', 'd') + ->where($qb->expr()->eq('cluster_id', $qb->createPositionalParameter($clusterId))) + ->orderBy( + $qb->createFunction( + $this->config->getSystemValue('dbtype', 'sqlite') === 'mysql' + ? 'RAND()' + : 'RANDOM()' + ) + ) + ->setMaxResults($n); + return $this->findEntities($qb); + } + protected function mapRowToEntity(array $row): Entity { try { return parent::mapRowToEntity($row); diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index 13054760..e06ed0a8 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -6,8 +6,7 @@ namespace OCA\Recognize\Service; -use OCA\Recognize\Clustering\MRDistance; -use OCA\Recognize\Clustering\MstClusterer; +use OCA\Recognize\Clustering\HDBSCAN; use OCA\Recognize\Db\FaceCluster; use OCA\Recognize\Db\FaceClusterMapper; use OCA\Recognize\Db\FaceDetection; @@ -16,15 +15,12 @@ use Rubix\ML\Kernels\Distance\Euclidean; class FaceClusterAnalyzer { + public const MIN_DATASET_SIZE = 30; public const MIN_SAMPLE_SIZE = 4; // Conservative value: 10 public const MIN_CLUSTER_SIZE = 5; // Conservative value: 10 - public const MAX_CLUSTER_EDGE_LENGHT = 99.0; - public const MIN_CLUSTER_SEPARATION = 0.0; - // For incremental clustering - public const MAX_INNER_CLUSTER_RADIUS = 0.44; public const MIN_DETECTION_SIZE = 0.03; - public const DIMENSIONS = 128; + public const SAMPLE_SIZE_EXISTING_CLUSTERS = 42; private FaceDetectionMapper $faceDetections; private FaceClusterMapper $faceClusters; @@ -43,86 +39,54 @@ public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapp public function calculateClusters(string $userId): void { $this->logger->debug('ClusterDebug: Retrieving face detections for user ' . $userId); - $detections = $this->faceDetections->findByUserId($userId); + $unclusteredDetections = $this->faceDetections->findUnclusteredByUserId($userId); - $detections = array_values(array_filter($detections, fn ($detection) => + $unclusteredDetections = array_values(array_filter($unclusteredDetections, fn ($detection) => $detection->getHeight() > self::MIN_DETECTION_SIZE && $detection->getWidth() > self::MIN_DETECTION_SIZE )); - $unclusteredDetections = $this->assignToExistingClusters($userId, $detections); - - if (count($unclusteredDetections) < max(self::MIN_SAMPLE_SIZE, self::MIN_CLUSTER_SIZE)) { + if (count($unclusteredDetections) < self::MIN_DATASET_SIZE) { $this->logger->debug('ClusterDebug: Not enough face detections found'); return; } $this->logger->debug('ClusterDebug: Found ' . count($unclusteredDetections) . " unclustered detections. Calculating clusters."); - $dataset = new Labeled(array_map(static function (FaceDetection $detection): array { - return $detection->getVector(); - }, $unclusteredDetections), array_combine(array_keys($unclusteredDetections), array_keys($unclusteredDetections)), false); - - $distanceKernel = new MRDistance(self::MIN_SAMPLE_SIZE, $dataset, new Euclidean()); - - $primsStartTime = microtime(true);// DEBUG + $sampledDetections = []; - // Prim's algorithm: - - $unconnectedVertices = array_combine(array_keys($unclusteredDetections), array_keys($unclusteredDetections)); - - $firstVertex = current($unconnectedVertices); - $firstVertexVector = $dataset->sample($firstVertex); - unset($unconnectedVertices[$firstVertex]); - - $edges = []; - foreach ($unconnectedVertices as $vertex) { - $edges[$vertex] = [$firstVertex, $distanceKernel->distance($firstVertex, $firstVertexVector, $vertex, $dataset->sample($vertex))]; + $existingClusters = $this->faceClusters->findByUserId($userId); + foreach ($existingClusters as $existingCluster) { + $sampled = $this->faceDetections->findClusterSample($existingCluster->getId(), self::SAMPLE_SIZE_EXISTING_CLUSTERS); + $sampledDetections = array_merge($sampledDetections, $sampled); } - while (count($unconnectedVertices) > 0) { - $minDistance = INF; - $minVertex = null; - - foreach ($unconnectedVertices as $vertex) { - $distance = $edges[$vertex][1]; - if ($distance < $minDistance) { - $minDistance = $distance; - $minVertex = $vertex; - } - } + $detections = array_merge($unclusteredDetections, $sampledDetections); - unset($unconnectedVertices[$minVertex]); - $minVertexVector = $dataset->sample($minVertex); - - foreach ($unconnectedVertices as $vertex) { - $distance = $distanceKernel->distance($minVertex, $minVertexVector, $vertex, $dataset->sample($vertex)); - if ($edges[$vertex][1] > $distance) { - $edges[$vertex] = [$minVertex,$distance]; - } - } - } - - $executionTime = (microtime(true) - $primsStartTime);// DEBUG - $this->logger->debug('ClusterDebug: Prims algo took '.$executionTime." secs.");// DEBUG - - // Calculate the face clusters based on the minimum spanning tree. + $dataset = new Labeled(array_map(static function (FaceDetection $detection): array { + return $detection->getVector(); + }, $detections), array_combine(array_keys($detections), array_keys($detections)), false); - $mstClusterer = new MstClusterer($edges, self::MIN_CLUSTER_SIZE, null, self::MAX_CLUSTER_EDGE_LENGHT, self::MIN_CLUSTER_SEPARATION); - $flatClusters = $mstClusterer->processCluster(); + $hdbscan = new HDBSCAN($dataset, self::MIN_CLUSTER_SIZE, self::MIN_SAMPLE_SIZE); $numberOfClusteredDetections = 0; + $clusters = $hdbscan->predict(); - foreach ($flatClusters as $flatCluster) { + foreach ($clusters as $flatCluster) { $cluster = new FaceCluster(); $cluster->setTitle(''); $cluster->setUserId($userId); $this->faceClusters->insert($cluster); - $detectionKeys = $flatCluster->getVertexKeys(); + $detectionKeys = $flatCluster->getClusterVertices(); $clusterCentroid = self::calculateCentroidOfDetections(array_map(static fn ($key) => $unclusteredDetections[$key], $detectionKeys)); foreach ($detectionKeys as $detectionKey) { + if ($detectionKey >= count($unclusteredDetections)) { + // This is a sampled, already clustered detection, ignore. + continue; + } + // If threshold is larger than 0 and $clusterCentroid is not the null vector if ($unclusteredDetections[$detectionKey]->getThreshold() > 0.0 && count(array_filter($clusterCentroid, fn ($el) => $el !== 0.0)) > 0) { // If a threshold is set for this detection and its vector is farther away from the centroid From 56c24756e2509d59d56f108d435822e85b0c88da Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Sun, 5 Feb 2023 18:28:13 +0100 Subject: [PATCH 11/17] Fix array index Signed-off-by: Marcel Klehr --- lib/Service/FaceClusterAnalyzer.php | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index e06ed0a8..dbb76339 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -78,7 +78,7 @@ public function calculateClusters(string $userId): void { $this->faceClusters->insert($cluster); $detectionKeys = $flatCluster->getClusterVertices(); - $clusterCentroid = self::calculateCentroidOfDetections(array_map(static fn ($key) => $unclusteredDetections[$key], $detectionKeys)); + $clusterCentroid = self::calculateCentroidOfDetections(array_map(static fn ($key) => $detections[$key], $detectionKeys)); foreach ($detectionKeys as $detectionKey) { From b6f1b59962c8a9bff3ef14a56f360c490223916a Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Sun, 5 Feb 2023 19:47:03 +0100 Subject: [PATCH 12/17] Fix array index Signed-off-by: Marcel Klehr --- lib/Service/FaceClusterAnalyzer.php | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index dbb76339..b5c16019 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -77,7 +77,7 @@ public function calculateClusters(string $userId): void { $cluster->setUserId($userId); $this->faceClusters->insert($cluster); - $detectionKeys = $flatCluster->getClusterVertices(); + $detectionKeys = array_keys($flatCluster->getClusterVertices()); $clusterCentroid = self::calculateCentroidOfDetections(array_map(static fn ($key) => $detections[$key], $detectionKeys)); From a60fcaf1492f57f424b8cdff57c19680721dbbe2 Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Sun, 5 Feb 2023 20:22:46 +0100 Subject: [PATCH 13/17] Fix incremental clustering Signed-off-by: Marcel Klehr --- lib/Service/FaceClusterAnalyzer.php | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index b5c16019..c04b1729 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -72,14 +72,23 @@ public function calculateClusters(string $userId): void { $clusters = $hdbscan->predict(); foreach ($clusters as $flatCluster) { - $cluster = new FaceCluster(); - $cluster->setTitle(''); - $cluster->setUserId($userId); - $this->faceClusters->insert($cluster); - $detectionKeys = array_keys($flatCluster->getClusterVertices()); $clusterCentroid = self::calculateCentroidOfDetections(array_map(static fn ($key) => $detections[$key], $detectionKeys)); + /** + * @var FaceDetection + */ + $detection = current(array_filter($detectionKeys, fn ($key) => $detections[$key]->getClusterId() !== null)); + $clusterId = $detection->getClusterId(); + + if ($clusterId !== null) { + $cluster = $this->faceClusters->find($clusterId); + } else { + $cluster = new FaceCluster(); + $cluster->setTitle(''); + $cluster->setUserId($userId); + $this->faceClusters->insert($cluster); + } foreach ($detectionKeys as $detectionKey) { if ($detectionKey >= count($unclusteredDetections)) { From cf268d45e8d370e61ecd8c18621e4c7f4dad953b Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Sun, 5 Feb 2023 20:44:03 +0100 Subject: [PATCH 14/17] Fix tests Signed-off-by: Marcel Klehr --- tests/ClusterTest.php | 45 ++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/tests/ClusterTest.php b/tests/ClusterTest.php index e7cff954..2b1cd55c 100644 --- a/tests/ClusterTest.php +++ b/tests/ClusterTest.php @@ -4,7 +4,9 @@ use OCA\Recognize\Db\FaceDetection; use OCA\Recognize\Db\FaceDetectionMapper; use OCA\Recognize\Service\FaceClusterAnalyzer; +use OCA\Recognize\Service\Logger; use Rubix\ML\Kernels\Distance\Euclidean; +use Symfony\Component\Console\Output\OutputInterface; use Test\TestCase; /** @@ -31,6 +33,12 @@ public function setUp(): void { $this->faceClusterAnalyzer = \OC::$server->get(FaceClusterAnalyzer::class); $this->faceClusterMapper = \OC::$server->get(FaceClusterMapper::class); + $logger = \OC::$server->get(Logger::class); + $cliOutput = $this->createMock(OutputInterface::class); + $cliOutput->method('writeln') + ->willReturnCallback(fn ($msg) => print($msg."\n")); + $logger->setCliOutput($cliOutput); + $clusters = $this->faceClusterMapper->findByUserId(self::TEST_USER1); foreach ($clusters as $cluster) { $this->faceClusterMapper->delete($cluster); @@ -193,15 +201,17 @@ public function testClusterMerging2() { $this->faceDetectionMapper->update($detection); } - $newDetection = new FaceDetection(); - $newDetection->setHeight(0.5); - $newDetection->setWidth(0.5); - $newDetection->setFileId(500000); - $nullVector = self::getNullVector(); - $nullVector[0] = 0.8; - $newDetection->setVector($nullVector); - $newDetection->setUserId(self::TEST_USER1); - $this->faceDetectionMapper->insert($newDetection); + for ($i = 0; $i < self::INITIAL_DETECTIONS_PER_CLUSTER; $i++) { + $newDetection = new FaceDetection(); + $newDetection->setHeight(0.5); + $newDetection->setWidth(0.5); + $newDetection->setFileId(500000+$i); + $nullVector = self::getNullVector(); + $nullVector[0] = 1 + 0.001 * $i; + $newDetection->setVector($nullVector); + $newDetection->setUserId(self::TEST_USER1); + $this->faceDetectionMapper->insert($newDetection); + } $this->faceClusterAnalyzer->calculateClusters(self::TEST_USER1); @@ -210,7 +220,7 @@ public function testClusterMerging2() { self::assertCount(1, $clusters); $detections = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 2 + 1, $detections); + self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 3, $detections); } /** @@ -238,7 +248,7 @@ public function testClusterMergingWithThirdAdditionalCluster() { $detection->setClusterId($clusters[0]->getId()); $this->faceDetectionMapper->update($detection); } - $this->faceClusterMapper->delete($clusters[0]); + $this->faceClusterMapper->delete($clusters[1]); $numOfDetections = self::INITIAL_DETECTIONS_PER_CLUSTER; $clusterValue = 3 * self::INITIAL_DETECTIONS_PER_CLUSTER; @@ -263,9 +273,8 @@ public function testClusterMergingWithThirdAdditionalCluster() { $detections1 = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); $detections2 = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); $counts = [count($detections1), count($detections2)]; - - self::assertCount(1, array_filter($counts, fn($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER)); - self::assertCount(1, array_filter($counts, fn($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER * 2)); + self::assertCount(1, array_filter($counts, fn ($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER), var_export($counts, true)); + self::assertCount(1, array_filter($counts, fn ($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER * 2), var_export($counts, true)); } /** @@ -325,15 +334,15 @@ public function testClusterTemptClusterMerging() { /** @var \OCA\Recognize\Db\FaceCluster[] $clusters */ $clusters = $this->faceClusterMapper->findByUserId(self::TEST_USER1); - self::assertCount(3, $clusters); + self::assertCount(4, $clusters); $detections1 = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); $detections2 = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); $detections3 = $this->faceDetectionMapper->findByClusterId($clusters[2]->getId()); - $counts = [count($detections1), count($detections2), count($detections3)]; + $detections4 = $this->faceDetectionMapper->findByClusterId($clusters[3]->getId()); + $counts = [count($detections1), count($detections2), count($detections3), count($detections4)]; - self::assertCount(2, array_filter($counts, fn($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER)); - self::assertCount(1, array_filter($counts, fn($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER * 2)); + self::assertCount(4, array_filter($counts, fn ($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER), var_export($counts, true)); } private static function getNullVector() { From bc0bc7e16318cbcb069e72902254c7b75b1d5f41 Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Sun, 5 Feb 2023 20:44:17 +0100 Subject: [PATCH 15/17] Fix Clustering Signed-off-by: Marcel Klehr --- lib/Service/FaceClusterAnalyzer.php | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index c04b1729..a18bf311 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -76,12 +76,12 @@ public function calculateClusters(string $userId): void { $clusterCentroid = self::calculateCentroidOfDetections(array_map(static fn ($key) => $detections[$key], $detectionKeys)); /** - * @var FaceDetection + * @var int|false $detection */ - $detection = current(array_filter($detectionKeys, fn ($key) => $detections[$key]->getClusterId() !== null)); - $clusterId = $detection->getClusterId(); + $detection = current(array_filter($detectionKeys, static fn ($key) => $detections[$key]->getClusterId() !== null)); - if ($clusterId !== null) { + if ($detection !== false) { + $clusterId = $detections[$detection]->getClusterId(); $cluster = $this->faceClusters->find($clusterId); } else { $cluster = new FaceCluster(); From 79d1c4857725aa50e94001e08e248e0c92caf304 Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Sun, 5 Feb 2023 21:13:26 +0100 Subject: [PATCH 16/17] Fix lint errors Signed-off-by: Marcel Klehr --- tests/ClusterTest.php | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ClusterTest.php b/tests/ClusterTest.php index 2b1cd55c..2ff11eb0 100644 --- a/tests/ClusterTest.php +++ b/tests/ClusterTest.php @@ -205,7 +205,7 @@ public function testClusterMerging2() { $newDetection = new FaceDetection(); $newDetection->setHeight(0.5); $newDetection->setWidth(0.5); - $newDetection->setFileId(500000+$i); + $newDetection->setFileId(500000 + $i); $nullVector = self::getNullVector(); $nullVector[0] = 1 + 0.001 * $i; $newDetection->setVector($nullVector); From 75586acd57a777dd0bdb4a12eb52bdd87a49d1aa Mon Sep 17 00:00:00 2001 From: Marcel Klehr Date: Mon, 6 Feb 2023 14:08:39 +0100 Subject: [PATCH 17/17] chore(psalm): Fix psalm errors Signed-off-by: Marcel Klehr --- lib/Clustering/DualTreeBall.php | 11 +- lib/Clustering/DualTreeClique.php | 18 +- lib/Clustering/HDBSCAN.php | 10 +- lib/Clustering/MrdBallTree.php | 25 +- lib/Clustering/MstClusterer.php | 5 +- lib/Clustering/MstSolver.php | 10 + lib/Service/FaceClusterAnalyzer.php | 51 -- psalm-baseline.xml | 761 +++++++++++++++++++++++++--- 8 files changed, 749 insertions(+), 142 deletions(-) diff --git a/lib/Clustering/DualTreeBall.php b/lib/Clustering/DualTreeBall.php index cfcbf799..a6e6297d 100644 --- a/lib/Clustering/DualTreeBall.php +++ b/lib/Clustering/DualTreeBall.php @@ -13,8 +13,8 @@ use function Rubix\ML\argmax; class DualTreeBall extends Ball { - protected float $longestDistanceInNode; - protected bool $fullyConnected; + protected float $longestDistanceInNode = INF; + protected bool $fullyConnected = false; protected $setId; @@ -42,6 +42,9 @@ public function resetFullyConnectedStatus(): void { } } + /** + * @return null|int|string + */ public function getSetId() { if (!$this->fullyConnected) { return null; @@ -53,6 +56,10 @@ public function isFullyConnected(): bool { return $this->fullyConnected; } + /** + * @param array $labelToSetId + * @return null|int|string + */ public function propagateSetChanges(array &$labelToSetId) { if ($this->fullyConnected) { // If we are already fully connected, we just need to check if the diff --git a/lib/Clustering/DualTreeClique.php b/lib/Clustering/DualTreeClique.php index 910f0311..8f5d34e7 100644 --- a/lib/Clustering/DualTreeClique.php +++ b/lib/Clustering/DualTreeClique.php @@ -13,11 +13,14 @@ use function Rubix\ML\argmax; class DualTreeClique extends Clique { - protected float $longestDistanceInNode; - protected bool $fullyConnected; + protected float $longestDistanceInNode = INF; + protected bool $fullyConnected = false; + /** + * @var null|int|string + */ protected $setId; - public function setLongestDistance($longestDistance): void { + public function setLongestDistance(float $longestDistance): void { $this->longestDistanceInNode = $longestDistance; } @@ -33,6 +36,9 @@ public function resetFullyConnectedStatus(): void { $this->fullyConnected = false; } + /** + * @return int|string|null + */ public function getSetId() { if (!$this->fullyConnected) { return null; @@ -45,6 +51,10 @@ public function isFullyConnected(): bool { return $this->fullyConnected; } + /** + * @param array $labelToSetId + * @return int|mixed|string|null + */ public function propagateSetChanges(array &$labelToSetId) { if ($this->fullyConnected) { $this->setId = $labelToSetId[$this->dataset->label(0)]; @@ -53,8 +63,6 @@ public function propagateSetChanges(array &$labelToSetId) { $labels = $this->dataset->labels(); - $label = - $setId = $labelToSetId[array_pop($labels)]; foreach ($labels as $label) { diff --git a/lib/Clustering/HDBSCAN.php b/lib/Clustering/HDBSCAN.php index 9d515198..b253c29f 100644 --- a/lib/Clustering/HDBSCAN.php +++ b/lib/Clustering/HDBSCAN.php @@ -41,13 +41,6 @@ class HDBSCAN { */ protected int $minClusterSize; - /** - * The maximum length edge allowed within a cluster. - * - * @var float - */ - protected float $maxEdgeLength; - /** * The number of neighbors used for determining core distance when * calculating mutual reachability distance between points. @@ -92,6 +85,7 @@ public function __construct(Labeled $dataset, int $minClusterSize = 5, int $samp } $kernel = $kernel ?? new SquaredDistance(); + $this->sampleSize = $sampleSize; $this->minClusterSize = $minClusterSize; $this->mstSolver = new MstSolver($dataset, 20, $sampleSize, $kernel, $oldCoreDistances, $useTrueMst); } @@ -134,7 +128,7 @@ public function params(): array { /** * Form clusters and make predictions from the dataset (hard clustering). * - * @return list//@return list + * @return list */ public function predict(): array { // Boruvka algorithm for MST generation diff --git a/lib/Clustering/MrdBallTree.php b/lib/Clustering/MrdBallTree.php index 1d738431..849c0f15 100644 --- a/lib/Clustering/MrdBallTree.php +++ b/lib/Clustering/MrdBallTree.php @@ -13,10 +13,10 @@ use Rubix\ML\Kernels\Distance\Distance; class MrdBallTree extends BallTree { - private Labeled $dataset; - private array $nativeInterpointCache; - private array $coreDistances; - private array $coreNeighborDistances; + private ?Labeled $dataset = null; + private array $nativeInterpointCache = []; + private array $coreDistances = []; + private array $coreNeighborDistances = []; private int $sampleSize; private array $nodeDistances; private \SplObjectStorage $nodeIds; @@ -48,6 +48,11 @@ public function getCoreNeighborDistances(): array { return $this->coreNeighborDistances; } + /** + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $queryNode + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $referenceNode + * @return float + */ public function nodeDistance($queryNode, $referenceNode): float { // Use cache to accelerate repeated queries if ($this->nodeIds->contains($queryNode)) { @@ -111,6 +116,14 @@ public function getDataset(): Labeled { return $this->dataset; } + /** + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $queryNode + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $referenceNode + * @param int $k + * @param float $maxRange + * @param array $bestDistances + * @return void + */ private function updateNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, &$bestDistances): void { $querySamples = $queryNode->dataset()->samples(); $queryLabels = $queryNode->dataset()->labels(); @@ -244,6 +257,10 @@ public function kNearestAll($k, float $maxRange = INF): void { $bestDistances[$label] = $maxRange; } + if ($this->root === null) { + return; + } + $treeRoot = $this->root; $treeRoot->resetFullyConnectedStatus(); diff --git a/lib/Clustering/MstClusterer.php b/lib/Clustering/MstClusterer.php index 849f5d3e..2f77b5ab 100644 --- a/lib/Clustering/MstClusterer.php +++ b/lib/Clustering/MstClusterer.php @@ -10,6 +10,10 @@ // TODO: store vertex lambda length (relative to cluster lambda length) for all vertices for soft clustering. class MstClusterer { private array $edges; + + /** + * @var list + */ private array $remainingEdges; private float $startingLambda; private float $clusterWeight; @@ -65,7 +69,6 @@ public function __construct(array $edges, ?array $mapVerticesToEdges, int $minim public function processCluster(): array { $currentLambda = $lastLambda = $this->startingLambda; - $edgeLength = INF; while (true) { $edgeCount = count($this->remainingEdges); diff --git a/lib/Clustering/MstSolver.php b/lib/Clustering/MstSolver.php index 37dc1ceb..d88a3383 100644 --- a/lib/Clustering/MstSolver.php +++ b/lib/Clustering/MstSolver.php @@ -90,6 +90,13 @@ private function updateEdges($queryNode, $referenceNode, array &$newEdges, array $queryNode->setLongestDistance($longestDistance); } + /** + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $queryNode + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $referenceNode + * @param array $newEdges + * @param array $vertexToSetId + * @return void + */ private function findSetNeighbors($queryNode, $referenceNode, array &$newEdges, array &$vertexToSetId): void { if ($queryNode->isFullyConnected() && $referenceNode->isFullyConnected()) { if ($queryNode->getSetId() === $referenceNode->getSetId()) { @@ -170,6 +177,9 @@ private function findSetNeighbors($queryNode, $referenceNode, array &$newEdges, return; } + /** + * @return list + */ public function getMst(): array { $edges = []; diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index a18bf311..c4683583 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -206,57 +206,6 @@ private function findFilesWithDuplicateFaces(array $detections): array { return $filesWithDuplicateFaces; } - /** - * @param string $userId - * @param list $detections - * @return list - * @throws \OCP\DB\Exception - */ - private function assignToExistingClusters(string $userId, array $detections): array { - $clusters = $this->faceClusters->findByUserId($userId); - - if (count($clusters) === 0) { - return $detections; - } - - $unclusteredDetections = []; - - foreach ($detections as $detection) { - $bestCluster = null; - $bestClusterDistance = 999; - if ($detection->getClusterId() !== null) { - continue; - } - foreach ($clusters as $cluster) { - $clusterDetections = $this->faceDetections->findByClusterId($cluster->getId()); - if (count($clusterDetections) > 50) { - $clusterDetections = array_map(fn ($key) => $clusterDetections[$key], array_rand($clusterDetections, 50)); - } - $clusterCentroid = self::calculateCentroidOfDetections($clusterDetections); - if ($detection->getThreshold() > 0 && self::distance($clusterCentroid, $detection->getVector()) >= $detection->getThreshold()) { - continue; - } - foreach ($clusterDetections as $clusterDetection) { - $distance = self::distance($clusterDetection->getVector(), $detection->getVector()); - if ( - $distance <= self::MAX_INNER_CLUSTER_RADIUS - && (!isset($bestCluster) || $distance < $bestClusterDistance) - ) { - $bestCluster = $cluster; - $bestClusterDistance = self::distance($clusterDetection->getVector(), $detection->getVector()); - break; - } - } - } - if ($bestCluster !== null) { - $this->faceDetections->assocWithCluster($detection, $bestCluster); - continue; - } - $unclusteredDetections[] = $detection; - } - return $unclusteredDetections; - } - private static ?Euclidean $distance; /** diff --git a/psalm-baseline.xml b/psalm-baseline.xml index 193c28ca..8672ec9d 100644 --- a/psalm-baseline.xml +++ b/psalm-baseline.xml @@ -185,78 +185,709 @@ getFileId - + + + $distances + + + Ball + Stats::mean($values) + + + Stats::mean($values) + compute + compute + isContinuous + new self($center, $radius, $subsets) + spatialSplit + + + $longestDistance + + + $setId + + + $values + $values + + + $column + $leftCentroid + $leftCentroid + $rightCentroid + $sample + $sample + + + $retVal + $setId + $this->longestDistanceInNode + $values + + + null|int|string + null|int|string + + + $this->setId + $this->setId + $this->setId + + + propagateSetChanges + propagateSetChanges + resetFullyConnectedStatus + resetLongestEdge + + + + + $distances + + + Clique + Stats::mean($values) + + + Stats::mean($values) + compute + isContinuous + new self($dataset, $center, $radius) + + + $values + $values + + + $column + $sample + + + $labelToSetId[$label] + $labelToSetId[array_pop($labels)] + + + $labelToSetId[array_pop($labels)] + + + $label + $values + + + $labelToSetId[$this->dataset->label(0)] + + + + + EstimatorType::clusterer() + + EstimatorType::clusterer() + Params::stringify($this->params()) + compatibility + + + $flatClusters + list<MstClusterer> + + + $dataset + + + + + center + center + center + center + cleanup compute - grow - nearest + compute + compute + compute + dataset + dataset + dataset + dataset + dataset + dataset + isPoint + radius + radius + radius + radius + radius + radius + radius + radius + radius + radius + radius + subsets - - $this->coreDistances[$index] - - - float - - - $this->dataset->sample($index) + + Labeled + + + $this->root + [$labels, $distances] + [$labels, $distances] + [array_keys($this->coreNeighborDistances[$sampleLabel]), array_values($this->coreNeighborDistances[$sampleLabel])] + + + $a + $b + $bestDistances + $k + $k + $label + $maxRange + $queryNode + $referenceNode + + + $bestDistances + $coreNeighborDistances + $coreNeighborDistances + $k + $label + $label + $largestOldCoreDistance + $longestDistance + $longestLeft + $longestRight + $maxRange + $oldDistances + $queryLeft->radius() + $queryNode + $queryNode->radius() + $queryNode->radius() + $queryNode->radius() + $queryRight->radius() + $querySample + $referenceNode + $referenceSample + $shortestDistance + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$sampleLabel] + $this->coreNeighborDistances[$sampleLabel] + reset($oldCoreNeighbors) + + + $a_vector + $b_vector + $sample + $sample + $sampleKey + $sampleKey + array_slice($neighborLabels, 0, $this->sampleSize) - - $this->coreDistances - + + $queryLabels[$queryKey] + $referenceLabels[$referenceKey] + + + $coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$neighborLabel][$label] + $this->coreNeighborDistances[$queryLabel][$referenceLabel] + $this->coreNeighborDistances[$queryLabel][$referenceLabel] + $this->coreNeighborDistances[$referenceLabel][$queryLabel] + $this->coreNeighborDistances[$referenceLabel][$queryLabel] + $this->nativeInterpointCache[$smallIndex][$largeIndex] + $this->nodeDistances[$smallIndex][$largeIndex] + + + $bestDistances[$label] + $bestDistances[$queryLabel] + $bestDistances[$queryLabel] + $coreNeighborDistances[$referenceLabel] + $oldCoreNeighbors[$label] + $oldCoreNeighbors[$neighborLabel] + $queryLabels[$queryKey] + $referenceLabels[$referenceKey] + $this->coreDistances[$label] + $this->coreDistances[$label] + $this->coreDistances[$label] + $this->coreDistances[$label] + $this->coreDistances[$neighborLabel] + $this->coreDistances[$queryLabel] + $this->coreDistances[$queryLabel] + $this->coreDistances[$queryLabel] + $this->coreDistances[$queryLabel] + $this->coreDistances[$referenceLabel] + $this->coreDistances[$referenceLabel] + $this->coreDistances[$referenceLabel] + $this->coreNeighborDistances[$a] + $this->coreNeighborDistances[$b] + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$neighborLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->nativeInterpointCache[$smallIndex] + $this->nativeInterpointCache[$smallIndex] + $this->nodeDistances[$smallIndex] + $this->nodeDistances[$smallIndex] + $updatedOldCoreLabels[$neighborLabel] + + + $this->nodeDistances[$smallIndex] + + + $bestDistance + $bestDistances[$queryLabel] + $child + $coreDistance + $coreNeighborDistances + $currentBound + $currentDistance + $label + $label + $label + $labels[] + $labels[] + $largeIndex + $largeIndex + $largeIndex + $largestOldCoreDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestLeft + $longestLeft + $longestRight + $longestRight + $neighborLabel + $nodeDistance + $oldDistances + $queryKey + $queryLabel + $queryLabels + $queryLeft + $queryLeft + $queryNodeId + $queryRight + $queryRight + $querySample + $querySamples + $radius + $referenceKey + $referenceLabel + $referenceLabels + $referenceLeft + $referenceNodeId + $referenceRight + $referenceSample + $referenceSamples + $shortestDistance + $smallIndex + $smallIndex + $smallIndex + $subChild + $subStack[] + + + float + float + float + float + + + children + getLongestDistance + getLongestDistance + getLongestDistance + labels + labels + left + left + left + radius + radius + radius + radius + radius + radius + radius + radius + radius + radius + right + right + right + samples + samples + setLongestDistance + + + $longestLeft + $longestRight + $queryNode->radius() + $queryNode->radius() + $queryNode->radius() + $shortestDistance + min($longestLeft, $longestRight) + + + $nodeDistance + $this->coreDistances[$label] + $this->coreDistances[$queryLabel] ?? INF + $this->coreNeighborDistances[$a][$b] + $this->coreNeighborDistances[$b][$a] + $this->nativeInterpointCache[$smallIndex][$largeIndex] + + + DualTreeBall + array{list<mixed>,list<float>} + array{list<mixed>,list<float>} + + + $this->dataset + $this->root + + + $sampleKey + $sampleKey + + + labels + labels + labels + + + children + dataset + dataset + dataset + dataset + + + resetFullyConnectedStatus + resetLongestEdge + + + $labels + - + $childCluster->getCoreEdges() - $finalLambda - $vertexConnectedFrom - - + $childClusterEdges1 $childClusterEdges1 $childClusterEdges2 + $childClusterEdges2 + $childClusterVerticesToEdges1 + $childClusterVerticesToEdges2 + $currentLambda + $this->coreEdges + $this->mapVerticesToEdges[$currentVertex] + $this->mapVerticesToEdges[$currentVertex] + $this->remainingEdges + $this->remainingEdges + $this->remainingEdges + $this->remainingEdges + $this->remainingEdges + $vertexConnectedFrom $vertexConnectedTo - $vertexConnectedTo - - - $a[1] - $a[1] - $b[1] - $b[1] - $currentLongestEdge[0] - $currentLongestEdge[1] + + + $a["distance"] + $a["distance"] + $b["distance"] + $b["distance"] + $currentLongestEdge["distance"] + $currentLongestEdge["vertexFrom"] + $currentLongestEdge["vertexTo"] + $edge["finalLambda"] + $edge["finalLambda"] + $edge["finalLambda"] + $edge["finalLambda"] + $edge["vertexFrom"] + $edge["vertexFrom"] + $edge["vertexFrom"] + $edge["vertexFrom"] + $edge["vertexTo"] + $edge["vertexTo"] + $edge["vertexTo"] + $edge["vertexTo"] + $edge['vertexFrom'] + $edge['vertexTo'] + $this->mapVerticesToEdges[$vertexConnectedFrom] + $this->mapVerticesToEdges[$vertexConnectedTo] - + + $edge['finalLambda'] + $this->edges[$currentLongestEdgeKey]["finalLambda"] + $this->edges[$edgeKey]['finalLambda'] + $this->edges[$edgeKey]['finalLambda'] + $this->edges[$edgeToPrune]['finalLambda'] + + + $mapVerticesToEdges[$edge['vertexFrom']] + $mapVerticesToEdges[$edge['vertexTo']] + $this->mapVerticesToEdges[$vertexConnectedFrom] + $this->mapVerticesToEdges[$vertexConnectedTo] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexTo"]] + $vertices[$edge["vertexTo"]] + $vertices[$edge["vertexTo"]] + $vertices[$edge["vertexTo"]] + + + $this->remainingEdges[$edgeKey] + $this->remainingEdges[$edgeKey] + $this->remainingEdges[$edgeToPrune] + + $childCluster $childrenWeight $currentLambda - $currentLambda $currentLongestEdge + $edge + $edge + $edge + $edge $edgeLength - $finalLambda - $lastLambda $lastLambda $this->clusterWeight + $this->coreEdges + $this->mapVerticesToEdges + $this->mapVerticesToEdges + $this->remainingEdges + $this->remainingEdges $vertexConnectedFrom + $vertexConnectedTo + $verticesInCluster[$currentVertex] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexTo"]] + $vertices[$edge["vertexTo"]] getClusterWeight getCoreEdges - + $childCluster->getClusterWeight() $currentLambda - $lastLambda $edgeLength - $edgeLength $lastLambda $this->clusterWeight - - $this->remainingEdges - count($childClusterEdges1) > count($childClusterEdges2) ? $childClusterEdges1 : $childClusterEdges2 + + $this->edges - - $vertexConnectedTo - $vertexConnectedTo + + $this->edges + $this->mapVerticesToEdges[$vertexConnectedFrom] + $this->mapVerticesToEdges[$vertexConnectedTo] + + + + + $child + $queryLeft + $queryLeft + $queryLeft + $queryRight + $queryRight + $queryRight + $referenceLeft + $referenceLeft + $referenceRight + $referenceRight + + + radius + radius + radius + radius + radius + radius + + + $edges + + + $queryNode + $referenceNode + + + $label + $longestLeft + $longestRight + $queryLabel + $queryLeft->radius() + $queryNode->radius() + $queryRight->radius() + $querySample + $referenceLabel + $referenceSample + $shortestDistance + + + $allLabels + + + $connectingEdge["vertexFrom"] + $connectingEdge["vertexTo"] + $newEdges[$queryNode->getSetId()]["distance"] + $queryLabels[$queryKey] + $referenceLabels[$referenceKey] + + + $newEdges[$querySetId] + $newEdges[$querySetId] + $newEdges[$querySetId] + $newEdges[$querySetId] + $queryLabels[$queryKey] + $referenceLabels[$referenceKey] + $vertexSets[$setId1] + $vertexSets[$setId1] + $vertexSets[$setId1] + $vertexSets[$setId2] + $vertexSets[$setId2] + $vertexToSetId[$connectingEdge["vertexFrom"]] + $vertexToSetId[$connectingEdge["vertexTo"]] + $vertexToSetId[$queryLabel] + $vertexToSetId[$referenceLabel] + $vertexToSetId[$vertexLabel] + + + $vertexSets[$setId1] + $vertexSets[$setId1] + $vertexSets[$setId2] + $vertexSets[$setId2] + + + $candidateDist + $connectingEdge + $currentBound + $edges[] + $label + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestLeft + $longestLeft + $longestRight + $longestRight + $neighborLabel + $queryKey + $queryLabel + $queryLabels + $querySample + $querySamples + $querySetId + $referenceKey + $referenceLabel + $referenceLabels + $referenceSample + $referenceSamples + $referenceSetId + $setId1 + $setId1 + $setId2 + $setId2 + $shortestDistance + $vertexLabel + $vertexToSetId[$vertexLabel] + + + dataset + dataset + dataset + dataset + labels + labels + radius + radius + samples + samples + setLongestDistance + + + $queryNode->radius() + 2 * $queryNode->radius() + min($longestLeft, $longestRight) + + + list<array{vertexFrom:int|string,vertexTo:int|string,distance:float}> + + + $queryLeft + $queryLeft + $queryLeft + $queryRight + $queryRight + $queryRight + $referenceLeft + $referenceLeft + $referenceRight + $referenceRight + + $newEdges + + + getLongestDistance + getLongestDistance + radius + radius + radius + radius + + + getLongestDistance + getLongestDistance + radius + radius + radius + radius + + + + + DataType::continuous() + + + DataType::continuous() + + + $a + $b + @@ -345,7 +976,8 @@ - + + $this->findEntities($qb) $this->findEntities($qb) $this->findEntities($qb) $this->findEntities($qb) @@ -353,13 +985,17 @@ $this->findEntity($qb) $this->findEntity($qb) + + findClusterSample + $qb->executeQuery()->fetchAll(\PDO::FETCH_COLUMN) list<string> - + FaceDetection FaceDetection + \OCA\Recognize\Db\FaceDetection[] list<\OCA\Recognize\Db\FaceDetection> list<\OCA\Recognize\Db\FaceDetection> list<\OCA\Recognize\Db\FaceDetection> @@ -554,41 +1190,24 @@ compute compute - - $key - - $detectionKeys - $firstVertex + $clusterId + $sampled - - $dataset->sample($vertex) - $dataset->sample($vertex) - $firstVertexVector - $minVertexVector + + array_map(static fn ($key) => $detections[$key], $detectionKeys) - + $unclusteredDetections[$detectionKey] - $unclusteredDetections[$key] - - - $detectionKey - $detectionKeys - $firstVertex - $flatCluster + + + $clusterId + $sampled - - getVertexKeys + + getClusterId + getClusterId - - $minVertex - - - $unconnectedVertices - - - $detection->getClusterId() !== null -