diff --git a/lib/Clustering/DualTreeBall.php b/lib/Clustering/DualTreeBall.php new file mode 100644 index 00000000..a6e6297d --- /dev/null +++ b/lib/Clustering/DualTreeBall.php @@ -0,0 +1,137 @@ +longestDistanceInNode) { + $this->longestDistanceInNode = $longestDistance; + } + } + + public function getLongestDistance(): float { + return $this->longestDistanceInNode; + } + + public function resetLongestEdge(): void { + $this->longestDistanceInNode = INF; + foreach ($this->children() as $child) { + $child->resetLongestEdge(); + } + } + + public function resetFullyConnectedStatus(): void { + $this->fullyConnected = false; + foreach ($this->children() as $child) { + $child->resetFullyConnectedStatus(); + } + } + + /** + * @return null|int|string + */ + public function getSetId() { + if (!$this->fullyConnected) { + return null; + } + return $this->setId; + } + + public function isFullyConnected(): bool { + return $this->fullyConnected; + } + + /** + * @param array $labelToSetId + * @return null|int|string + */ + public function propagateSetChanges(array &$labelToSetId) { + if ($this->fullyConnected) { + // If we are already fully connected, we just need to check if the + // set id has changed + foreach ($this->children() as $child) { + $this->setId = $child->propagateSetChanges($labelToSetId); + } + + return $this->setId; + } + + // If, and only if, both children are fully connected and in the same set id then + // we, too, are fully connected + $setId = null; + foreach ($this->children() as $child) { + $retVal = $child->propagateSetChanges($labelToSetId); + + if ($retVal === null) { + return null; + } + + if ($setId !== null && $setId !== $retVal) { + return null; + } + + $setId = $retVal; + } + + $this->setId = $setId; + $this->fullyConnected = true; + + return $this->setId; + } + + /** + * Factory method to build a hypersphere by splitting the dataset into left and right clusters. + * + * @param \Rubix\ML\Datasets\Labeled $dataset + * @param \Rubix\ML\Kernels\Distance\Distance $kernel + * @return self + */ + public static function split(Labeled $dataset, Distance $kernel): self { + $center = []; + + foreach ($dataset->features() as $column => $values) { + if ($dataset->featureType($column)->isContinuous()) { + $center[] = Stats::mean($values); + } else { + $center[] = argmax(array_count_values($values)); + } + } + + $distances = []; + + foreach ($dataset->samples() as $sample) { + $distances[] = $kernel->compute($sample, $center); + } + + $radius = max($distances) ?: 0.0; + + $leftCentroid = $dataset->sample(argmax($distances)); + + $distances = []; + + foreach ($dataset->samples() as $sample) { + $distances[] = $kernel->compute($sample, $leftCentroid); + } + + $rightCentroid = $dataset->sample(argmax($distances)); + + $subsets = $dataset->spatialSplit($leftCentroid, $rightCentroid, $kernel); + + return new self($center, $radius, $subsets); + } +} diff --git a/lib/Clustering/DualTreeClique.php b/lib/Clustering/DualTreeClique.php new file mode 100644 index 00000000..8f5d34e7 --- /dev/null +++ b/lib/Clustering/DualTreeClique.php @@ -0,0 +1,106 @@ +longestDistanceInNode = $longestDistance; + } + + public function getLongestDistance(): float { + return $this->longestDistanceInNode; + } + + public function resetLongestEdge(): void { + $this->longestDistanceInNode = INF; + } + + public function resetFullyConnectedStatus(): void { + $this->fullyConnected = false; + } + + /** + * @return int|string|null + */ + public function getSetId() { + if (!$this->fullyConnected) { + return null; + } + + return $this->setId; + } + + public function isFullyConnected(): bool { + return $this->fullyConnected; + } + + /** + * @param array $labelToSetId + * @return int|mixed|string|null + */ + public function propagateSetChanges(array &$labelToSetId) { + if ($this->fullyConnected) { + $this->setId = $labelToSetId[$this->dataset->label(0)]; + return $this->setId; + } + + $labels = $this->dataset->labels(); + + $setId = $labelToSetId[array_pop($labels)]; + + foreach ($labels as $label) { + if ($setId !== $labelToSetId[$label]) { + return null; + } + } + + $this->fullyConnected = true; + $this->setId = $setId; + + return $this->setId; + } + + /** + * Terminate a branch with a dataset. + * + * @param \Rubix\ML\Datasets\Labeled $dataset + * @param \Rubix\ML\Kernels\Distance\Distance $kernel + * @return self + */ + public static function terminate(Labeled $dataset, Distance $kernel): self { + $center = []; + + foreach ($dataset->features() as $column => $values) { + if ($dataset->featureType($column)->isContinuous()) { + $center[] = Stats::mean($values); + } else { + $center[] = argmax(array_count_values($values)); + } + } + + $distances = []; + + foreach ($dataset->samples() as $sample) { + $distances[] = $kernel->compute($sample, $center); + } + $radius = max($distances) ?: 0.0; + return new self($dataset, $center, $radius); + } +} diff --git a/lib/Clustering/HDBSCAN.php b/lib/Clustering/HDBSCAN.php new file mode 100644 index 00000000..b253c29f --- /dev/null +++ b/lib/Clustering/HDBSCAN.php @@ -0,0 +1,162 @@ +sampleSize = $sampleSize; + $this->minClusterSize = $minClusterSize; + $this->mstSolver = new MstSolver($dataset, 20, $sampleSize, $kernel, $oldCoreDistances, $useTrueMst); + } + + public function getCoreNeighborDistances(): array { + return $this->mstSolver->getCoreNeighborDistances(); + } + + /** + * Return the estimator type. + * + * @return \Rubix\ML\EstimatorType + */ + public function type(): EstimatorType { + return EstimatorType::clusterer(); + } + + /** + * Return the data types that the estimator is compatible with. + * + * @return list<\Rubix\ML\DataType> + */ + public function compatibility(): array { + return $this->mstSolver->kernel()->compatibility(); + } + + /** + * Return the settings of the hyper-parameters in an associative array. + * + * @return mixed[] + */ + public function params(): array { + return [ + 'min cluster size' => $this->minClusterSize, + 'sample size' => $this->sampleSize, + 'dual tree' => $this->mstSolver, + ]; + } + + /** + * Form clusters and make predictions from the dataset (hard clustering). + * + * @return list + */ + public function predict(): array { + // Boruvka algorithm for MST generation + $edges = $this->mstSolver->getMst(); + + // Boruvka complete, $edges now contains our mutual reachability distance MST + if ($this->mstSolver->kernel() instanceof SquaredDistance) { + foreach ($edges as &$edge) { + $edge["distance"] = sqrt($edge["distance"]); + } + } + unset($edge); + + // TODO: Min cluster separation/edge length of MstClusterer to the caller of this class + $mstClusterer = new MstClusterer($edges, null, $this->minClusterSize, null, 0.0); + $flatClusters = $mstClusterer->processCluster(); + + return $flatClusters; + } + + /** + * Return the string representation of the object. + * + * @internal + * + * @return string + */ + public function __toString(): string { + return 'HDBSCAN (' . Params::stringify($this->params()) . ')'; + } +} diff --git a/lib/Clustering/MrdBallTree.php b/lib/Clustering/MrdBallTree.php new file mode 100644 index 00000000..849c0f15 --- /dev/null +++ b/lib/Clustering/MrdBallTree.php @@ -0,0 +1,666 @@ +maxLeafSize = $maxLeafSize; + $this->sampleSize = $sampleSize; + + $this->kernel = $kernel ?? new SquaredDistance(); + $this->radiusDiamFactor = $this->kernel instanceof SquaredDistance ? 4 : 2; + + $this->nodeDistances = []; + $this->nodeIds = new \SplObjectStorage(); + } + + public function getCoreNeighborDistances(): array { + return $this->coreNeighborDistances; + } + + /** + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $queryNode + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $referenceNode + * @return float + */ + public function nodeDistance($queryNode, $referenceNode): float { + // Use cache to accelerate repeated queries + if ($this->nodeIds->contains($queryNode)) { + $queryNodeId = $this->nodeIds[$queryNode]; + } else { + $queryNodeId = $this->nodeIds->count(); + $this->nodeIds[$queryNode] = $queryNodeId; + } + + if ($this->nodeIds->contains($referenceNode)) { + $referenceNodeId = $this->nodeIds[$referenceNode]; + } else { + $referenceNodeId = $this->nodeIds->count(); + $this->nodeIds[$referenceNode] = $referenceNodeId; + } + + if ($referenceNodeId === $queryNodeId) { + return (-$this->radiusDiamFactor * $queryNode->radius()); + } + + $smallIndex = min($queryNodeId, $referenceNodeId); + $largeIndex = max($queryNodeId, $referenceNodeId); + + if (isset($this->nodeDistances[$smallIndex][$largeIndex])) { + $nodeDistance = $this->nodeDistances[$smallIndex][$largeIndex]; + } else { + $nodeDistance = $this->kernel->compute($queryNode->center(), $referenceNode->center()); + + if ($this->kernel instanceof SquaredDistance) { + $nodeDistance = sqrt($nodeDistance) - sqrt($queryNode->radius()) - sqrt($referenceNode->radius()); + $nodeDistance = abs($nodeDistance) * $nodeDistance; + } else { + $nodeDistance = $nodeDistance - $queryNode->radius() - $referenceNode->radius(); + } + + $this->nodeDistances[$smallIndex][$largeIndex] = $nodeDistance; + } + + return $nodeDistance; + } + + /** + * Get tree root. + * + * @internal + * + * @return DualTreeBall + */ + public function getRoot(): Ball { + return $this->root; + } + + /** + * Get the dataset the tree was grown on. + * + * @internal + * + * @return Labeled + */ + public function getDataset(): Labeled { + return $this->dataset; + } + + /** + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $queryNode + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $referenceNode + * @param int $k + * @param float $maxRange + * @param array $bestDistances + * @return void + */ + private function updateNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, &$bestDistances): void { + $querySamples = $queryNode->dataset()->samples(); + $queryLabels = $queryNode->dataset()->labels(); + $referenceSamples = $referenceNode->dataset()->samples(); + $referenceLabels = $referenceNode->dataset()->labels(); + + $longestDistance = 0.0; + $shortestDistance = INF; + + foreach ($querySamples as $queryKey => $querySample) { + $queryLabel = $queryLabels[$queryKey]; + + $bestDistance = $bestDistances[$queryLabel]; + + foreach ($referenceSamples as $referenceKey => $referenceSample) { + $referenceLabel = $referenceLabels[$referenceKey]; + + if ($queryLabel === $referenceLabel) { + continue; + } + + // Calculate native distance + $distance = $this->cachedComputeNative($queryLabel, $querySample, $referenceLabel, $referenceSample); + + if ($distance < $bestDistance) { + //Minimize array queries within these loops: + $coreNeighborDistances = & $this->coreNeighborDistances[$queryLabel]; + $coreNeighborDistances[$referenceLabel] = $distance; + + if (count($coreNeighborDistances) >= $k) { + asort($coreNeighborDistances); + $coreNeighborDistances = array_slice($coreNeighborDistances, 0, $k, true); + $bestDistance = min(end($coreNeighborDistances), $maxRange); + } + } + } + + if ($bestDistance > $longestDistance) { + $longestDistance = $bestDistance; + } + + if ($bestDistance < $shortestDistance) { + $shortestDistance = $bestDistance; + } + $bestDistances[$queryLabel] = $bestDistance; + } + + if ($this->kernel instanceof SquaredDistance) { + $longestDistance = min($longestDistance, (2 * sqrt($queryNode->radius()) + sqrt($shortestDistance)) ** 2); + } else { + $longestDistance = min($longestDistance, 2 * $queryNode->radius() + $shortestDistance); + } + $queryNode->setLongestDistance($longestDistance); + } + + private function findNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, &$bestDistances): void { + $nodeDistance = $this->nodeDistance($queryNode, $referenceNode); + + if ($nodeDistance > 0.0) { + // Calculate smallest possible bound (i.e., d(Q) ): + $currentBound = $queryNode->getLongestDistance(); + + // If node distance is greater than the longest possible edge in this node, + // prune this reference node + if ($nodeDistance > $currentBound) { + return; + } + } + + if ($queryNode instanceof DualTreeClique && $referenceNode instanceof DualTreeClique) { + $this->updateNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, $bestDistances); + return; + } + + if ($queryNode instanceof DualTreeClique) { + foreach ($referenceNode->children() as $child) { + $this->findNearestNeighbors($queryNode, $child, $k, $maxRange, $bestDistances); + } + return; + } + + if ($referenceNode instanceof DualTreeClique) { + $longestDistance = 0.0; + + $queryLeft = $queryNode->left(); + $queryRight = $queryNode->right(); + + $this->findNearestNeighbors($queryLeft, $referenceNode, $k, $maxRange, $bestDistances); + $this->findNearestNeighbors($queryRight, $referenceNode, $k, $maxRange, $bestDistances); + } else { + // --> if ($queryNode instanceof DualTreeBall && $referenceNode instanceof DualTreeBall) + $queryLeft = $queryNode->left(); + $queryRight = $queryNode->right(); + $referenceLeft = $referenceNode->left(); + $referenceRight = $referenceNode->right(); + + // TODO: traverse closest neighbor nodes first + $this->findNearestNeighbors($queryLeft, $referenceLeft, $k, $maxRange, $bestDistances); + $this->findNearestNeighbors($queryLeft, $referenceRight, $k, $maxRange, $bestDistances); + $this->findNearestNeighbors($queryRight, $referenceLeft, $k, $maxRange, $bestDistances); + $this->findNearestNeighbors($queryRight, $referenceRight, $k, $maxRange, $bestDistances); + } + + $longestLeft = $queryLeft->getLongestDistance(); + $longestRight = $queryRight->getLongestDistance(); + + // TODO: min($longestLeft, $longestRight) + 2 * ($queryNode->radius()) <--- Can be made tighter by using the shortest distance from child. + if ($this->kernel instanceof SquaredDistance) { + $longestDistance = max($longestLeft, $longestRight); + $longestLeft = (sqrt($longestLeft) + 2 * (sqrt($queryNode->radius()) - sqrt($queryLeft->radius()))) ** 2; + $longestRight = (sqrt($longestRight) + 2 * (sqrt($queryNode->radius()) - sqrt($queryRight->radius()))) ** 2; + $longestDistance = min($longestDistance, min($longestLeft, $longestRight), (sqrt(min($longestLeft, $longestRight)) + 2 * (sqrt($queryNode->radius()))) ** 2); + } else { + $longestDistance = max($longestLeft, $longestRight); + $longestLeft = $longestLeft + 2 * ($queryNode->radius() - $queryLeft->radius()); + $longestRight = $longestRight + 2 * ($queryNode->radius() - $queryRight->radius()); + $longestDistance = min($longestDistance, min($longestLeft, $longestRight), min($longestLeft, $longestRight) + 2 * ($queryNode->radius())); + } + + $queryNode->setLongestDistance($longestDistance); + + return; + } + + public function kNearestAll($k, float $maxRange = INF): void { + $this->coreNeighborDistances = []; + + $allLabels = $this->dataset->labels(); + $bestDistances = []; + foreach ($allLabels as $label) { + $bestDistances[$label] = $maxRange; + } + + if ($this->root === null) { + return; + } + + $treeRoot = $this->root; + + $treeRoot->resetFullyConnectedStatus(); + $treeRoot->resetLongestEdge(); + + $this->findNearestNeighbors($treeRoot, $treeRoot, $k, $maxRange, $bestDistances); + } + + /** + * Precompute core distances for the current dataset to accelerate + * subsequent MRD queries. Optionally, utilize core distances that've + * been previously determined for (a subset of) the current dataset. + * Returns the updated core distances for future use. + * + * @internal + * + * @param array|null $oldCoreNeighbors + * @return array + */ + + public function precalculateCoreDistances(?array $oldCoreNeighbors = null) { + if (empty($this->dataset)) { + throw new \Exception("Precalculation of core distances requested but dataset is empty. Call ->grow() first!"); + } + + $labels = $this->dataset->labels(); + + if ($oldCoreNeighbors !== null && !empty($oldCoreNeighbors) && count(reset($oldCoreNeighbors)) >= $this->sampleSize) { + // Determine the search radius for core distances based on the largest old + // core distance (points father than that cannot shorten the old core distances) + $largestOldCoreDistance = 0.0; + + // Utilize old (possibly stale) core distance data + foreach ($oldCoreNeighbors as $label => $oldDistances) { + $coreDistance = (array_values($oldDistances))[$this->sampleSize - 1]; + + if ($coreDistance > $largestOldCoreDistance) { + $largestOldCoreDistance = $coreDistance; + } + + $this->coreNeighborDistances[$label] = $oldDistances; + $this->coreDistances[$label] = $coreDistance; + } + + $updatedOldCoreLabels = []; + + // Don't recalculate core distances for the old labels + $labels = array_filter($labels, function ($label) use ($oldCoreNeighbors) { + return !isset($oldCoreNeighbors[$label]); + }); + + foreach ($labels as $label) { + [$neighborLabels, $neighborDistances] = $this->cachedRange($label, $largestOldCoreDistance); + // TODO: cachedRange may not return $this->sampleSize number of labels. + $this->coreNeighborDistances[$label] = array_combine(array_slice($neighborLabels, 0, $this->sampleSize), array_slice($neighborDistances, 0, $this->sampleSize)); + $this->coreDistances[$label] = $neighborDistances[$this->sampleSize - 1]; + + // If one of the old vertices is within the update radius of this new vertex, + // check whether the old core distance needs to be updated. + foreach ($neighborLabels as $distanceKey => $neighborLabel) { + if (isset($oldCoreNeighbors[$neighborLabel])) { + $newDistance = $neighborDistances[$distanceKey]; + if ($newDistance < $this->coreDistances[$neighborLabel]) { + $this->coreNeighborDistances[$neighborLabel][$label] = $newDistance; + $updatedOldCoreLabels[$neighborLabel] = true; + } + } + } + } + + foreach (array_keys($updatedOldCoreLabels) as $label) { + asort($this->coreNeighborDistances[$label]); + $this->coreNeighborDistances[$label] = array_slice($this->coreNeighborDistances[$label], 0, $this->sampleSize, true); + $this->coreDistances[$label] = end($this->coreNeighborDistances[$label]); + } + } else { // $oldCoreNeighbors === null + $this->kNearestAll($this->sampleSize, INF); + + foreach ($this->dataset->labels() as $label) { + $this->coreDistances[$label] = end($this->coreNeighborDistances[$label]); + } + } + + return $this->coreNeighborDistances; + } + + /** + * Inserts a new neighbor to core neighbors if the distance + * is greater than the current largest distance for the query label. + * + * Returns the updated core distance or INF if there are less than $this->sampleSize neighbors. + * + * @internal + * + * @param mixed $queryLabel + * @param mixed $referenceLabel + * @param float $distance + * @return float + */ + private function insertToCoreDistances($queryLabel, $referenceLabel, $distance): float { + // Update the core distances of the queryLabel + if (isset($this->coreDistances[$queryLabel])) { + if ($this->coreDistances[$queryLabel] > $distance) { + $this->coreNeighborDistances[$queryLabel][$referenceLabel] = $distance; + asort($this->coreNeighborDistances[$queryLabel]); + + $this->coreNeighborDistances[$queryLabel] = array_slice($this->coreNeighborDistances[$queryLabel], 0, $this->sampleSize, true); + $this->coreDistances[$queryLabel] = end($this->coreNeighborDistances[$queryLabel]); + } + } else { + $this->coreNeighborDistances[$queryLabel][$referenceLabel] = $distance; + + if (count($this->coreNeighborDistances[$queryLabel]) >= $this->sampleSize) { + asort($this->coreNeighborDistances[$queryLabel]); + + $this->coreNeighborDistances[$queryLabel] = array_slice($this->coreNeighborDistances[$queryLabel], 0, $this->sampleSize, true); + $this->coreDistances[$queryLabel] = end($this->coreNeighborDistances[$queryLabel]); + } + } + + // Update the core distances of the referenceLabel (this is not necessary, but *may* accelerate the algo slightly) + if (isset($this->coreDistances[$referenceLabel])) { + if ($this->coreDistances[$referenceLabel] > $distance) { + $this->coreNeighborDistances[$referenceLabel][$queryLabel] = $distance; + asort($this->coreNeighborDistances[$referenceLabel]); + + $this->coreNeighborDistances[$referenceLabel] = array_slice($this->coreNeighborDistances[$referenceLabel], 0, $this->sampleSize, true); + $this->coreDistances[$referenceLabel] = end($this->coreNeighborDistances[$referenceLabel]); + } + } else { + $this->coreNeighborDistances[$referenceLabel][$queryLabel] = $distance; + + if (count($this->coreNeighborDistances[$referenceLabel]) > $this->sampleSize) { + asort($this->coreNeighborDistances[$referenceLabel]); + $this->coreNeighborDistances[$referenceLabel] = array_slice($this->coreNeighborDistances[$referenceLabel], 0, $this->sampleSize, true); + $this->coreDistances[$referenceLabel] = end($this->coreNeighborDistances[$referenceLabel]); + } + } + + return $this->coreDistances[$queryLabel] ?? INF; + } + + /** + * Compute the mutual reachability distance between two vectors. + * + * @internal + * + * @param int|string $a + * @param int|string $b + * @return float + */ + public function computeMrd($a, array $a_vector, $b, array $b_vector): float { + $distance = $this->cachedComputeNative($a, $a_vector, $b, $b_vector); + + return max($distance, $this->getCoreDistance($a), $this->getCoreDistance($b)); + } + + public function getCoreDistance($label): float { + if (!isset($this->coreDistances[$label])) { + [$labels, $distances] = $this->getCoreNeighbors($label); + $this->coreDistances[$label] = end($distances); + } + + return $this->coreDistances[$label]; + } + + public function cachedComputeNative($a, array $a_vector, $b, array $b_vector, bool $storeNewCalculations = true): float { + if (isset($this->coreNeighborDistances[$a][$b])) { + return $this->coreNeighborDistances[$a][$b]; + } + if (isset($this->coreNeighborDistances[$b][$a])) { + return $this->coreNeighborDistances[$b][$a]; + } + + if ($a < $b) { + $smallIndex = $a; + $largeIndex = $b; + } else { + $smallIndex = $b; + $largeIndex = $a; + } + + if (!isset($this->nativeInterpointCache[$smallIndex][$largeIndex])) { + $distance = $this->kernel->compute($a_vector, $b_vector); + if ($storeNewCalculations) { + $this->nativeInterpointCache[$smallIndex][$largeIndex] = $distance; + } + return $distance; + } + + return $this->nativeInterpointCache[$smallIndex][$largeIndex]; + } + + /** + * Run a n nearest neighbors search on a single label and return the neighbor labels, and distances in a 2-tuple + * + * + * @internal + * + * @param int|string $sampleLabel + * @param bool $useCachedValues + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + * @return array{list,list} + */ + public function getCoreNeighbors($sampleLabel, bool $useCachedValues = true): array { + if ($useCachedValues && isset($this->coreNeighborDistances[$sampleLabel])) { + return [array_keys($this->coreNeighborDistances[$sampleLabel]), array_values($this->coreNeighborDistances[$sampleLabel])]; + } + + $sampleKey = array_search($sampleLabel, $this->dataset->labels()); + $sample = $this->dataset->sample($sampleKey); + + $squaredDistance = $this->kernel instanceof SquaredDistance; + + /** @var list **/ + $stack = [$this->root]; + $stackDistances = [0.0]; + $radius = INF; + + $labels = $distances = []; + + while ($current = array_pop($stack)) { + $currentDistance = array_pop($stackDistances); + + if ($currentDistance > $radius) { + continue; + } + + if ($current instanceof DualTreeBall) { + foreach ($current->children() as $child) { + if ($child instanceof Hypersphere) { + $distance = $this->kernel->compute($sample, $child->center()); + + if ($squaredDistance) { + $distance = sqrt($distance); + $childRadius = sqrt($child->radius()); + $distance = $distance - $childRadius; + $distance = abs($distance) * $distance; + } else { + $distance = $distance - $child->radius(); + } + + if ($distance < $radius) { + $stack[] = $child; + $stackDistances[] = $distance; + } + } + } + array_multisort($stackDistances, SORT_DESC, $stack); + } elseif ($current instanceof DualTreeClique) { + $dataset = $current->dataset(); + $neighborLabels = $dataset->labels(); + + foreach ($dataset->samples() as $i => $neighbor) { + if ($neighborLabels[$i] === $sampleLabel) { + continue; + } + + $distance = $this->cachedComputeNative($sampleLabel, $sample, $neighborLabels[$i], $neighbor); + + if ($distance <= $radius) { + $labels[] = $neighborLabels[$i]; + $distances[] = $distance; + } + } + + if (count($labels) >= $this->sampleSize) { + array_multisort($distances, $labels); + $radius = $distances[$this->sampleSize - 1]; + $labels = array_slice($labels, 0, $this->sampleSize); + $distances = array_slice($distances, 0, $this->sampleSize); + } + } + } + return [$labels, $distances]; + } + + /** + * Return all labels, and distances within a given radius of a sample. + * + * + * @internal + * + * @param int $sampleLabel + * @param float $radius + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + * @throws \Rubix\ML\Exceptions\RuntimeException + * @return array{list,list} + */ + public function cachedRange($sampleLabel, float $radius): array { + $sampleKey = array_search($sampleLabel, $this->dataset->labels()); + $sample = $this->dataset->sample($sampleKey); + + $squaredDistance = $this->kernel instanceof SquaredDistance; + + /** @var list **/ + $stack = [$this->root]; + + $labels = $distances = []; + + while ($current = array_pop($stack)) { + if ($current instanceof DualTreeBall) { + foreach ($current->children() as $child) { + if ($child instanceof Hypersphere) { + $distance = $this->kernel->compute($sample, $child->center()); + + if ($squaredDistance) { + $distance = sqrt($distance); + $childRadius = sqrt($child->radius()); + $minDistance = $distance - $childRadius; + $minDistance = abs($minDistance) * $minDistance; + $maxDistance = ($distance + $childRadius) ** 2; + } else { + $childRadius = $child->radius(); + $minDistance = $distance - $childRadius; + $maxDistance = $distance + $childRadius; + } + + if ($minDistance < $radius) { + if ($maxDistance < $radius && $child instanceof DualTreeBall) { + // The whole child is within the specified radius: greedily add all sub-children recursively to the stack + $subStack = [$child]; + while ($subStackCurrent = array_pop($subStack)) { + foreach ($subStackCurrent->children() as $subChild) { + if ($subChild instanceof DualTreeClique) { + $stack[] = $subChild; + } else { + $subStack[] = $subChild; + } + } + } + } else { + $stack[] = $child; + } + } + } + } + } elseif ($current instanceof DualTreeClique) { + $dataset = $current->dataset(); + $neighborLabels = $dataset->labels(); + + foreach ($dataset->samples() as $i => $neighbor) { + $distance = $this->cachedComputeNative($sampleLabel, $sample, $neighborLabels[$i], $neighbor); + + if ($distance <= $radius) { + $labels[] = $neighborLabels[$i]; + $distances[] = $distance; + } + } + } + } + array_multisort($distances, $labels); + return [$labels, $distances]; + } + + /** + * Insert a root node and recursively split the dataset until a terminating + * condition is met. This also sets the dataset that will be used to calculate + * core distances. Previously calculated core distances will be stored/used + * despite calling grow, unless precalculateCoreDistances() is called again. + * + * @internal + * + * @param \Rubix\ML\Datasets\Labeled $dataset + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + */ + public function grow(Labeled $dataset): void { + $this->dataset = $dataset; + $this->root = DualTreeBall::split($dataset, $this->kernel); + + $stack = [$this->root]; + + while ($current = array_pop($stack)) { + [$left, $right] = $current->subsets(); + + $current->cleanup(); + + if ($left->numSamples() > $this->maxLeafSize) { + $node = DualTreeBall::split($left, $this->kernel); + + $current->attachLeft($node); + + $stack[] = $node; + } elseif (!$left->empty()) { + $current->attachLeft(DualTreeClique::terminate($left, $this->kernel)); + } + + if ($right->numSamples() > $this->maxLeafSize) { + $node = DualTreeBall::split($right, $this->kernel); + + if ($node->isPoint()) { + $current->attachRight(DualTreeClique::terminate($right, $this->kernel)); + } else { + $current->attachRight($node); + + $stack[] = $node; + } + } elseif (!$right->empty()) { + $current->attachRight(DualTreeClique::terminate($right, $this->kernel)); + } + } + } +} diff --git a/lib/Clustering/MstClusterer.php b/lib/Clustering/MstClusterer.php new file mode 100644 index 00000000..2f77b5ab --- /dev/null +++ b/lib/Clustering/MstClusterer.php @@ -0,0 +1,272 @@ + + */ + private array $remainingEdges; + private float $startingLambda; + private float $clusterWeight; + private int $minimumClusterSize; + private array $coreEdges; + private bool $isRoot; + private array $mapVerticesToEdges; + private float $minClusterSeparation; + + public function __construct(array $edges, ?array $mapVerticesToEdges, int $minimumClusterSize, ?float $startingLambda = null, float $minClusterSeparation = 0.1) { + //Ascending sort of edges while perserving original keys. + $this->edges = $edges; + + uasort($this->edges, function ($a, $b) { + if ($a["distance"] > $b["distance"]) { + return 1; + } + if ($a["distance"] < $b["distance"]) { + return -1; + } + return 0; + }); + + $this->remainingEdges = $this->edges; + + if ($mapVerticesToEdges === null) { + $mapVerticesToEdges = []; + foreach ($this->edges as $edgeIndex => $edge) { + $mapVerticesToEdges[$edge['vertexFrom']][$edgeIndex] = true; + $mapVerticesToEdges[$edge['vertexTo']][$edgeIndex] = true; + } + } + + $this->mapVerticesToEdges = $mapVerticesToEdges; + + if (is_null($startingLambda)) { + $this->isRoot = true; + $this->startingLambda = 0.0; + } else { + $this->isRoot = false; + $this->startingLambda = $startingLambda; + } + + $this->minimumClusterSize = $minimumClusterSize; + + $this->coreEdges = []; + + $this->clusterWeight = 0.0; + + + $this->minClusterSeparation = $minClusterSeparation; + } + + public function processCluster(): array { + $currentLambda = $lastLambda = $this->startingLambda; + + while (true) { + $edgeCount = count($this->remainingEdges); + + if ($edgeCount < ($this->minimumClusterSize - 1)) { + foreach ($this->coreEdges as &$edge) { + $edge['finalLambda'] = $currentLambda; + } + unset($edge); + + foreach (array_keys($this->remainingEdges) as $edgeKey) { + $this->edges[$edgeKey]['finalLambda'] = $currentLambda; + } + + return [$this]; + } + + if ($edgeCount < (2 * $this->minimumClusterSize - 1)) { + // The end is near; this cluster cannot be split into two anymore + $this->coreEdges = $this->remainingEdges; + } + + $currentLongestEdgeKey = array_key_last($this->remainingEdges); + $currentLongestEdge = array_pop($this->remainingEdges); + + $vertexConnectedFrom = $currentLongestEdge["vertexFrom"]; + $vertexConnectedTo = $currentLongestEdge["vertexTo"]; + $edgeLength = $currentLongestEdge["distance"]; + + unset($this->mapVerticesToEdges[$vertexConnectedFrom][$currentLongestEdgeKey]); + unset($this->mapVerticesToEdges[$vertexConnectedTo][$currentLongestEdgeKey]); + + if ($edgeLength > 0.0) { + $currentLambda = 1 / $edgeLength; + } + + $this->clusterWeight += ($currentLambda - $lastLambda) * $edgeCount; + $lastLambda = $currentLambda; + + $this->edges[$currentLongestEdgeKey]["finalLambda"] = $currentLambda; + + if (!$this->pruneFromCluster($vertexConnectedTo, $currentLambda) && !$this->pruneFromCluster($vertexConnectedFrom, $currentLambda)) { + // This cluster will (probably) split into two child clusters: + + [$childClusterEdges1, $childClusterVerticesToEdges1] = $this->getChildClusterComponents($vertexConnectedTo); + [$childClusterEdges2, $childClusterVerticesToEdges2] = $this->getChildClusterComponents($vertexConnectedFrom); + + if ($edgeLength < $this->minClusterSeparation) { + if (count($childClusterEdges1) > count($childClusterEdges2)) { + $this->remainingEdges = $childClusterEdges1; + $this->mapVerticesToEdges = $childClusterVerticesToEdges1; + } else { + $this->remainingEdges = $childClusterEdges2; + $this->mapVerticesToEdges = $childClusterVerticesToEdges2; + } + continue; + } + + // Choose clusters using excess of mass method: + // Return a list of children if the weight of all children is more than $this->clusterWeight. + // Otherwise return the current cluster and discard the children. This way we "choose" a combination + // of clusters that weigh the most (i.e. have most (excess of) mass). Always discard the root cluster. + + + $childCluster1 = new MstClusterer($childClusterEdges1, $childClusterVerticesToEdges1, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation); + $childCluster2 = new MstClusterer($childClusterEdges2, $childClusterVerticesToEdges2, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation); + + // Resolve all chosen child clusters recursively + $childClusters = array_merge($childCluster1->processCluster(), $childCluster2->processCluster()); + + $childrenWeight = 0.0; + foreach ($childClusters as $childCluster) { + $childrenWeight += $childCluster->getClusterWeight(); + $this->coreEdges = array_merge($this->coreEdges, $childCluster->getCoreEdges()); + } + + if (($childrenWeight > $this->clusterWeight) || $this->isRoot) { + return $childClusters; + } else { + foreach (array_keys($this->remainingEdges) as $edgeKey) { + $this->edges[$edgeKey]['finalLambda'] = $currentLambda; + } + } + + return [$this]; + } + } + } + + private function pruneFromCluster(int $vertexId, float $currentLambda): bool { + $edgeIndicesToPrune = []; + $verticesToPrune = []; + $vertexStack = [$vertexId]; + + while (!empty($vertexStack)) { + $currentVertex = array_pop($vertexStack); + $verticesToPrune[] = $currentVertex; + + if (count($verticesToPrune) >= $this->minimumClusterSize) { + return false; + } + + foreach (array_keys($this->mapVerticesToEdges[$currentVertex]) as $edgeKey) { + if (isset($edgeIndicesToPrune[$edgeKey])) { + continue; + } + + if ($this->remainingEdges[$edgeKey]["vertexFrom"] === $currentVertex) { + $vertexStack[] = $this->remainingEdges[$edgeKey]["vertexTo"]; + $edgeIndicesToPrune[$edgeKey] = true; + } elseif ($this->remainingEdges[$edgeKey]["vertexTo"] === $currentVertex) { + $vertexStack[] = $this->remainingEdges[$edgeKey]["vertexFrom"]; + $edgeIndicesToPrune[$edgeKey] = true; + } + } + } + + // Prune edges + foreach (array_keys($edgeIndicesToPrune) as $edgeToPrune) { + $this->edges[$edgeToPrune]['finalLambda'] = $currentLambda; + unset($this->remainingEdges[$edgeToPrune]); + } + + // Prune vertices to edges map (not stricly necessary but saves some memory) + foreach ($verticesToPrune as $vertexLabel) { + unset($this->mapVerticesToEdges[$vertexLabel]); + } + + return true; + } + + private function getChildClusterComponents(int $vertexId): array { + $vertexStack = [$vertexId]; + $edgeIndicesInCluster = []; + $verticesInCluster = []; + + while (!empty($vertexStack)) { + $currentVertex = array_pop($vertexStack); + $verticesInCluster[$currentVertex] = $this->mapVerticesToEdges[$currentVertex]; + + foreach (array_keys($this->mapVerticesToEdges[$currentVertex]) as $edgeKey) { + if (isset($edgeIndicesInCluster[$edgeKey])) { + continue; + } + + if ($this->remainingEdges[$edgeKey]["vertexFrom"] === $currentVertex) { + $vertexStack[] = $this->remainingEdges[$edgeKey]["vertexTo"]; + $edgeIndicesInCluster[$edgeKey] = true; + } elseif ($this->remainingEdges[$edgeKey]["vertexTo"] === $currentVertex) { + $vertexStack[] = $this->remainingEdges[$edgeKey]["vertexFrom"]; + $edgeIndicesInCluster[$edgeKey] = true; + } + } + } + + // Collecting the edges is done in a separate loop to perserve the ordering according to length. + // (See constructor.) + $edgesInCluster = []; + foreach ($this->remainingEdges as $edgeKey => $edge) { + if (isset($edgeIndicesInCluster[$edgeKey])) { + $edgesInCluster[$edgeKey] = $edge; + } + } + + return [$edgesInCluster, $verticesInCluster]; + } + + public function getClusterWeight(): float { + return $this->clusterWeight; + } + + public function getClusterVertices(): array { + $vertices = []; + + foreach ($this->edges as $edge) { + $vertices[$edge["vertexTo"]] = min($edge["finalLambda"], $vertices[$edge["vertexTo"]] ?? INF); + $vertices[$edge["vertexFrom"]] = min($edge["finalLambda"], $vertices[$edge["vertexFrom"]] ?? INF); + } + + return $vertices; + } + + public function getCoreEdges(): array { + return $this->coreEdges; + } + + public function getClusterEdges(): array { + return $this->edges; + } + + public function getCoreVertices(): array { + $vertices = []; + + foreach ($this->coreEdges as $edge) { + $vertices[$edge["vertexTo"]] = min($edge["finalLambda"], $vertices[$edge["vertexTo"]] ?? INF); + $vertices[$edge["vertexFrom"]] = min($edge["finalLambda"], $vertices[$edge["vertexFrom"]] ?? INF); + } + + return $vertices; + } +} diff --git a/lib/Clustering/MstSolver.php b/lib/Clustering/MstSolver.php new file mode 100644 index 00000000..d88a3383 --- /dev/null +++ b/lib/Clustering/MstSolver.php @@ -0,0 +1,281 @@ +kernel = $kernel ?? new SquaredDistance(); + + $this->tree = new MrdBallTree($maxLeafSize, $sampleSize, $this->kernel); + + $this->tree->grow($fullDataset); + $this->tree->precalculateCoreDistances($oldCoreDistances); + + $this->useTrueMst = $useTrueMst; + } + + public function kernel(): Distance { + return $this->kernel; + } + + public function getCoreNeighborDistances(): array { + return $this->tree->getCoreNeighborDistances(); + } + + public function getTree(): MrdBallTree { + return $this->tree; + } + + private function updateEdges($queryNode, $referenceNode, array &$newEdges, array &$vertexToSetId): void { + $querySamples = $queryNode->dataset()->samples(); + $queryLabels = $queryNode->dataset()->labels(); + $referenceSamples = $referenceNode->dataset()->samples(); + $referenceLabels = $referenceNode->dataset()->labels(); + + $longestDistance = 0.0; + $shortestDistance = INF; + + foreach ($querySamples as $queryKey => $querySample) { + $queryLabel = $queryLabels[$queryKey]; + $querySetId = $vertexToSetId[$queryLabel]; + + if ($this->tree->getCoreDistance($queryLabel) > ($newEdges[$querySetId]["distance"] ?? INF)) { + // The core distance of the current vertex is greater than the current best edge + // for this setId. This means that the MRD will always be greater than the current best. + continue; + } + + foreach ($referenceSamples as $referenceKey => $referenceSample) { + $referenceLabel = $referenceLabels[$referenceKey]; + $referenceSetId = $vertexToSetId[$referenceLabel]; + + if ($querySetId === $referenceSetId) { + continue; + } + + $distance = $this->tree->computeMrd($queryLabel, $querySample, $referenceLabel, $referenceSample); + + if ($distance < ($newEdges[$querySetId]["distance"] ?? INF)) { + $newEdges[$querySetId] = ["vertexFrom" => $queryLabel, "vertexTo" => $referenceLabel, "distance" => $distance]; + } + } + $candidateDist = $newEdges[$querySetId]["distance"] ?? INF; + if ($candidateDist > $longestDistance) { + $longestDistance = $candidateDist; + } + + if ($candidateDist < $shortestDistance) { + $shortestDistance = $candidateDist; + } + } + + // Update the bound of the query node + if ($this->kernel instanceof SquaredDistance) { + $longestDistance = min($longestDistance, (2 * sqrt($queryNode->radius()) + sqrt($shortestDistance)) ** 2); + } else { + $longestDistance = min($longestDistance, 2 * $queryNode->radius() + $shortestDistance); + } + + $queryNode->setLongestDistance($longestDistance); + } + + /** + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $queryNode + * @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $referenceNode + * @param array $newEdges + * @param array $vertexToSetId + * @return void + */ + private function findSetNeighbors($queryNode, $referenceNode, array &$newEdges, array &$vertexToSetId): void { + if ($queryNode->isFullyConnected() && $referenceNode->isFullyConnected()) { + if ($queryNode->getSetId() === $referenceNode->getSetId()) { + // These nodes are connected and in the same set, so we can prune this reference node. + return; + } + } + + // if d(Q,R) > d(Q) then + // return; + + $nodeDistance = $this->tree->nodeDistance($queryNode, $referenceNode); + + if ($nodeDistance > 0.0) { + // Calculate smallest possible bound (i.e., d(Q) ): + if ($queryNode->isFullyConnected()) { + $currentBound = min($newEdges[$queryNode->getSetId()]["distance"] ?? INF, $queryNode->getLongestDistance()); + } else { + $currentBound = $queryNode->getLongestDistance(); + } + // If node distance is greater than the longest possible edge in this node, + // prune this reference node + if ($nodeDistance > $currentBound) { + return; + } + } + + if ($queryNode instanceof DualTreeClique && $referenceNode instanceof DualTreeClique) { + $this->updateEdges($queryNode, $referenceNode, $newEdges, $vertexToSetId); + return; + } + + if ($queryNode instanceof DualTreeClique) { + foreach ($referenceNode->children() as $child) { + $this->findSetNeighbors($queryNode, $child, $newEdges, $vertexToSetId); + } + return; + } + + if ($referenceNode instanceof DualTreeClique) { + $longestDistance = 0.0; + + $queryLeft = $queryNode->left(); + $queryRight = $queryNode->right(); + + $this->findSetNeighbors($queryLeft, $referenceNode, $newEdges, $vertexToSetId); + $this->findSetNeighbors($queryRight, $referenceNode, $newEdges, $vertexToSetId); + } else { // if ($queryNode instanceof DualTreeBall && $referenceNode instanceof DualTreeBall) + $queryLeft = $queryNode->left(); + $queryRight = $queryNode->right(); + $referenceLeft = $referenceNode->left(); + $referenceRight = $referenceNode->right(); + + $this->findSetNeighbors($queryLeft, $referenceLeft, $newEdges, $vertexToSetId); + $this->findSetNeighbors($queryRight, $referenceRight, $newEdges, $vertexToSetId); + $this->findSetNeighbors($queryLeft, $referenceRight, $newEdges, $vertexToSetId); + $this->findSetNeighbors($queryRight, $referenceLeft, $newEdges, $vertexToSetId); + } + + $longestLeft = $queryLeft->getLongestDistance(); + $longestRight = $queryRight->getLongestDistance(); + + // TODO: min($longestLeft, $longestRight) + 2 * ($queryNode->radius()) <--- Can be made tighter? + if ($this->kernel instanceof SquaredDistance) { + $longestDistance = max($longestLeft, $longestRight); + $longestLeft = (sqrt($longestLeft) + 2 * (sqrt($queryNode->radius()) - sqrt($queryLeft->radius()))) ** 2; + $longestRight = (sqrt($longestRight) + 2 * (sqrt($queryNode->radius()) - sqrt($queryRight->radius()))) ** 2; + $longestDistance = min($longestDistance, min($longestLeft, $longestRight), (sqrt(min($longestLeft, $longestRight)) + 2 * (sqrt($queryNode->radius()))) ** 2); + } else { + $longestDistance = max($longestLeft, $longestRight); + $longestLeft = $longestLeft + 2 * ($queryNode->radius() - $queryLeft->radius()); + $longestRight = $longestRight + 2 * ($queryNode->radius() - $queryRight->radius()); + $longestDistance = min($longestDistance, min($longestLeft, $longestRight), min($longestLeft, $longestRight) + 2 * ($queryNode->radius())); + } + + $queryNode->setLongestDistance($longestDistance); + + return; + } + + /** + * @return list + */ + public function getMst(): array { + $edges = []; + + // MST generation using dual-tree boruvka algorithm + + $treeRoot = $this->tree->getRoot(); + + $treeRoot->resetFullyConnectedStatus(); + + $allLabels = $this->tree->getDataset()->labels(); + + $vertexToSetId = array_combine($allLabels, range(0, count($allLabels) - 1)); + + $vertexSets = []; + foreach ($vertexToSetId as $vertex => $setId) { + $vertexSets[$setId] = [$vertex]; + } + + if (!$this->useTrueMst) { + $treeRoot->resetLongestEdge(); + } + + // Use nearest neighbors known from determining core distances for each vertex to + // get the first set of $newEdges (we essentially can skip the first round of Boruvka): + $newEdges = []; + + foreach ($allLabels as $label) { + [$coreNeighborLabels, $coreNeighborDistances] = $this->tree->getCoreNeighbors($label); + + $coreDistance = end($coreNeighborDistances); + + foreach ($coreNeighborLabels as $neighborLabel) { + if ($neighborLabel === $label) { + continue; + } + + if ($this->tree->getCoreDistance($neighborLabel) <= $coreDistance) { + // This point is our nearest neighbor in mutual reachability terms, so + // an edge spanning these vertices will belong to the MST. + $newEdges[] = ["vertexFrom" => $label, "vertexTo" => $neighborLabel, "distance" => $coreDistance]; + break; + } + } + } + + /////////////////////////////////////////////////////////////////////////////// + // Main dual tree Boruvka loop: + + while (true) { + //Add new edges + //Update vertex to set/set to vertex mappings + foreach ($newEdges as $connectingEdge) { + $setId1 = $vertexToSetId[$connectingEdge["vertexFrom"]]; + $setId2 = $vertexToSetId[$connectingEdge["vertexTo"]]; + + if ($setId1 === $setId2) { + // These sets have already been merged earlier in this loop + continue; + } + + $edges[] = $connectingEdge; + + if (count($vertexSets[$setId1]) < count($vertexSets[$setId2])) { + // Make a switch such that the larger set is always Id1 + [$setId1, $setId2] = [$setId2, $setId1]; + } + + // Assign all vertices in set 2 to set 1 + foreach ($vertexSets[$setId2] as $vertexLabel) { + $vertexToSetId[$vertexLabel] = $setId1; + } + + $vertexSets[$setId1] = array_merge($vertexSets[$setId1], $vertexSets[$setId2]); + unset($vertexSets[$setId2]); + } + + // Check for exit condition + if (count($vertexSets) === 1) { + break; + } + + //Update the tree + if ($this->useTrueMst || empty($newEdges)) { + $treeRoot->resetLongestEdge(); + } + + if (!empty($newEdges)) { + $treeRoot->propagateSetChanges($vertexToSetId); + } + + // Clear the array for a set of new edges + $newEdges = []; + + $this->findSetNeighbors($treeRoot, $treeRoot, $newEdges, $vertexToSetId); + } + + return $edges; + } +} diff --git a/lib/Clustering/SquaredDistance.php b/lib/Clustering/SquaredDistance.php new file mode 100644 index 00000000..6b4fda94 --- /dev/null +++ b/lib/Clustering/SquaredDistance.php @@ -0,0 +1,64 @@ + + */ + public function compatibility(): array { + return [ + DataType::continuous(), + ]; + } + + /** + * Compute the distance between two vectors. + * + * @internal + * + * @param list $a + * @param list $b + * @return float + */ + public function compute(array $a, array $b): float { + $distance = 0.0; + + foreach ($a as $i => $value) { + $distance += ($value - $b[$i]) ** 2; + } + + return $distance; + } + + /** + * Return the string representation of the object. + * + * @internal + * + * @return string + */ + public function __toString(): string { + return 'Squared distance'; + } +} diff --git a/lib/Db/FaceDetectionMapper.php b/lib/Db/FaceDetectionMapper.php index 231e88ab..db5997b1 100644 --- a/lib/Db/FaceDetectionMapper.php +++ b/lib/Db/FaceDetectionMapper.php @@ -9,12 +9,16 @@ use OCP\AppFramework\Db\Entity; use OCP\AppFramework\Db\QBMapper; use OCP\DB\QueryBuilder\IQueryBuilder; +use OCP\IConfig; use OCP\IDBConnection; class FaceDetectionMapper extends QBMapper { - public function __construct(IDBConnection $db) { + private IConfig $config; + + public function __construct(IDBConnection $db, IConfig $config) { parent::__construct($db, 'recognize_face_detections', FaceDetection::class); $this->db = $db; + $this->config = $config; } /** @@ -130,6 +134,36 @@ public function findUserIds() :array { return $qb->executeQuery()->fetchAll(\PDO::FETCH_COLUMN); } + /** + * @param string $userId + * @return \OCA\Recognize\Db\FaceDetection[] + * @throws \OCP\DB\Exception + */ + public function findUnclusteredByUserId(string $userId) : array { + $qb = $this->db->getQueryBuilder(); + $qb->select(FaceDetection::$columns) + ->from('recognize_face_detections') + ->where($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))) + ->andWhere($qb->expr()->isNull('cluster_id')); + return $this->findEntities($qb); + } + + public function findClusterSample(int $clusterId, int $n) { + $qb = $this->db->getQueryBuilder(); + $qb->select(FaceDetection::$columns) + ->from('recognize_face_detections', 'd') + ->where($qb->expr()->eq('cluster_id', $qb->createPositionalParameter($clusterId))) + ->orderBy( + $qb->createFunction( + $this->config->getSystemValue('dbtype', 'sqlite') === 'mysql' + ? 'RAND()' + : 'RANDOM()' + ) + ) + ->setMaxResults($n); + return $this->findEntities($qb); + } + protected function mapRowToEntity(array $row): Entity { try { return parent::mapRowToEntity($row); diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index 29651769..c4683583 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -6,30 +6,29 @@ namespace OCA\Recognize\Service; +use OCA\Recognize\Clustering\HDBSCAN; use OCA\Recognize\Db\FaceCluster; use OCA\Recognize\Db\FaceClusterMapper; use OCA\Recognize\Db\FaceDetection; use OCA\Recognize\Db\FaceDetectionMapper; -use Rubix\ML\Clusterers\DBSCAN; -use Rubix\ML\Datasets\Unlabeled; -use Rubix\ML\Graph\Trees\BallTree; +use Rubix\ML\Datasets\Labeled; use Rubix\ML\Kernels\Distance\Euclidean; class FaceClusterAnalyzer { - public const MIN_CLUSTER_DENSITY = 6; - public const MAX_INNER_CLUSTER_RADIUS = 0.44; + public const MIN_DATASET_SIZE = 30; + public const MIN_SAMPLE_SIZE = 4; // Conservative value: 10 + public const MIN_CLUSTER_SIZE = 5; // Conservative value: 10 + public const MIN_DETECTION_SIZE = 0.03; public const DIMENSIONS = 128; - public const MIN_DETECTION_SIZE = 0.09; + public const SAMPLE_SIZE_EXISTING_CLUSTERS = 42; private FaceDetectionMapper $faceDetections; private FaceClusterMapper $faceClusters; - private TagManager $tagManager; private Logger $logger; - public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapper $faceClusters, TagManager $tagManager, Logger $logger) { + public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapper $faceClusters, Logger $logger) { $this->faceDetections = $faceDetections; $this->faceClusters = $faceClusters; - $this->tagManager = $tagManager; $this->logger = $logger; } @@ -38,72 +37,88 @@ public function __construct(FaceDetectionMapper $faceDetections, FaceClusterMapp * @throws \JsonException */ public function calculateClusters(string $userId): void { - $this->logger->debug('Find face detection for use '.$userId); - $detections = $this->faceDetections->findByUserId($userId); + $this->logger->debug('ClusterDebug: Retrieving face detections for user ' . $userId); - $detections = array_values(array_filter($detections, fn ($detection) => + $unclusteredDetections = $this->faceDetections->findUnclusteredByUserId($userId); + + $unclusteredDetections = array_values(array_filter($unclusteredDetections, fn ($detection) => $detection->getHeight() > self::MIN_DETECTION_SIZE && $detection->getWidth() > self::MIN_DETECTION_SIZE )); - if (count($detections) === 0) { - $this->logger->debug('No face detections found'); + if (count($unclusteredDetections) < self::MIN_DATASET_SIZE) { + $this->logger->debug('ClusterDebug: Not enough face detections found'); return; } - $unclusteredDetections = $this->assignToExistingClusters($userId, $detections); + $this->logger->debug('ClusterDebug: Found ' . count($unclusteredDetections) . " unclustered detections. Calculating clusters."); - if (count($unclusteredDetections) === 0) { - $this->logger->debug('No unclustered face detections left after incremental run'); - return; - } + $sampledDetections = []; - // Here we use RubixMLs DBSCAN clustering algorithm - $dataset = new Unlabeled(array_map(function (FaceDetection $detection) : array { - return $detection->getVector(); - }, $unclusteredDetections)); - - $clusterer = new DBSCAN(self::MAX_INNER_CLUSTER_RADIUS, self::MIN_CLUSTER_DENSITY, new BallTree(100, new Euclidean())); - $this->logger->debug('Calculate clusters for '.count($unclusteredDetections).' faces'); - $results = $clusterer->predict($dataset); - $numClusters = max($results); - - $this->logger->debug('Found '.$numClusters.' new face clusters'); + $existingClusters = $this->faceClusters->findByUserId($userId); + foreach ($existingClusters as $existingCluster) { + $sampled = $this->faceDetections->findClusterSample($existingCluster->getId(), self::SAMPLE_SIZE_EXISTING_CLUSTERS); + $sampledDetections = array_merge($sampledDetections, $sampled); + } - for ($i = 0; $i <= $numClusters; $i++) { - $keys = array_keys($results, $i); - $clusterDetections = array_map(function ($key) use ($unclusteredDetections) : FaceDetection { - return $unclusteredDetections[$key]; - }, $keys); + $detections = array_merge($unclusteredDetections, $sampledDetections); - $cluster = new FaceCluster(); - $cluster->setTitle(''); - $cluster->setUserId($userId); - $this->faceClusters->insert($cluster); + $dataset = new Labeled(array_map(static function (FaceDetection $detection): array { + return $detection->getVector(); + }, $detections), array_combine(array_keys($detections), array_keys($detections)), false); + + $hdbscan = new HDBSCAN($dataset, self::MIN_CLUSTER_SIZE, self::MIN_SAMPLE_SIZE); + + $numberOfClusteredDetections = 0; + $clusters = $hdbscan->predict(); + + foreach ($clusters as $flatCluster) { + $detectionKeys = array_keys($flatCluster->getClusterVertices()); + $clusterCentroid = self::calculateCentroidOfDetections(array_map(static fn ($key) => $detections[$key], $detectionKeys)); + + /** + * @var int|false $detection + */ + $detection = current(array_filter($detectionKeys, static fn ($key) => $detections[$key]->getClusterId() !== null)); + + if ($detection !== false) { + $clusterId = $detections[$detection]->getClusterId(); + $cluster = $this->faceClusters->find($clusterId); + } else { + $cluster = new FaceCluster(); + $cluster->setTitle(''); + $cluster->setUserId($userId); + $this->faceClusters->insert($cluster); + } - $clusterCentroid = self::calculateCentroidOfDetections($clusterDetections); + foreach ($detectionKeys as $detectionKey) { + if ($detectionKey >= count($unclusteredDetections)) { + // This is a sampled, already clustered detection, ignore. + continue; + } - foreach ($clusterDetections as $detection) { // If threshold is larger than 0 and $clusterCentroid is not the null vector - if ($detection->getThreshold() > 0.0 && count(array_filter($clusterCentroid, fn ($el) => $el !== 0.0)) > 0) { + if ($unclusteredDetections[$detectionKey]->getThreshold() > 0.0 && count(array_filter($clusterCentroid, fn ($el) => $el !== 0.0)) > 0) { // If a threshold is set for this detection and its vector is farther away from the centroid // than the threshold, skip assigning this detection to the cluster - $distanceValue = self::distance($clusterCentroid, $detection->getVector()); - if ($distanceValue >= $detection->getThreshold()) { + $distanceValue = self::distance($clusterCentroid, $unclusteredDetections[$detectionKey]->getVector()); + if ($distanceValue >= $unclusteredDetections[$detectionKey]->getThreshold()) { continue; } } - $this->faceDetections->assocWithCluster($detection, $cluster); + $this->faceDetections->assocWithCluster($unclusteredDetections[$detectionKey], $cluster); + $numberOfClusteredDetections += 1; } } - $this->pruneDuplicateFilesFromClusters($userId); + $this->logger->debug('ClusterDebug: Clustering complete. Total num of clustered detections: ' . $numberOfClusteredDetections); + $this->pruneClusters($userId); } /** * @throws \OCP\DB\Exception */ - public function pruneDuplicateFilesFromClusters(string $userId): void { + public function pruneClusters(string $userId): void { $clusters = $this->faceClusters->findByUserId($userId); if (count($clusters) === 0) { @@ -124,7 +139,8 @@ public function pruneDuplicateFilesFromClusters(string $userId): void { foreach ($filesWithDuplicateFaces as $fileDetections) { $detectionsByDistance = []; foreach ($fileDetections as $detection) { - $detectionsByDistance[$detection->getId()] = self::distance($centroid, $detection->getVector()); + $distance = new Euclidean(); + $detectionsByDistance[$detection->getId()] = $distance->compute($centroid, $detection->getVector()); } asort($detectionsByDistance); $bestMatchingDetectionId = array_keys($detectionsByDistance)[0]; @@ -157,12 +173,12 @@ public static function calculateCentroidOfDetections(array $detections): array { } foreach ($detections as $detection) { - $sum = array_map(function ($el, $el2) { + $sum = array_map(static function ($el, $el2) { return $el + $el2; }, $detection->getVector(), $sum); } - $centroid = array_map(function ($el) use ($detections) { + $centroid = array_map(static function ($el) use ($detections) { return $el / count($detections); }, $sum); @@ -183,64 +199,13 @@ private function findFilesWithDuplicateFaces(array $detections): array { } /** @var array $filesWithDuplicateFaces */ - $filesWithDuplicateFaces = array_filter($files, function ($detections) { + $filesWithDuplicateFaces = array_filter($files, static function ($detections) { return count($detections) > 1; }); return $filesWithDuplicateFaces; } - /** - * @param string $userId - * @param list $detections - * @return list - * @throws \OCP\DB\Exception - */ - private function assignToExistingClusters(string $userId, array $detections): array { - $clusters = $this->faceClusters->findByUserId($userId); - - if (count($clusters) === 0) { - return $detections; - } - - $unclusteredDetections = []; - - foreach ($detections as $detection) { - $bestCluster = null; - $bestClusterDistance = 999; - if ($detection->getClusterId() !== null) { - continue; - } - foreach ($clusters as $cluster) { - $clusterDetections = $this->faceDetections->findByClusterId($cluster->getId()); - if (count($clusterDetections) > 50) { - $clusterDetections = array_map(fn ($key) => $clusterDetections[$key], array_rand($clusterDetections, 50)); - } - $clusterCentroid = self::calculateCentroidOfDetections($clusterDetections); - if ($detection->getThreshold() > 0 && self::distance($clusterCentroid, $detection->getVector()) >= $detection->getThreshold()) { - continue; - } - foreach ($clusterDetections as $clusterDetection) { - $distance = self::distance($clusterDetection->getVector(), $detection->getVector()); - if ( - $distance <= self::MAX_INNER_CLUSTER_RADIUS - && (!isset($bestCluster) || $distance < $bestClusterDistance) - ) { - $bestCluster = $cluster; - $bestClusterDistance = self::distance($clusterDetection->getVector(), $detection->getVector()); - break; - } - } - } - if ($bestCluster !== null) { - $this->faceDetections->assocWithCluster($detection, $bestCluster); - continue; - } - $unclusteredDetections[] = $detection; - } - return $unclusteredDetections; - } - private static ?Euclidean $distance; /** diff --git a/psalm-baseline.xml b/psalm-baseline.xml index f9cbccea..8672ec9d 100644 --- a/psalm-baseline.xml +++ b/psalm-baseline.xml @@ -185,6 +185,710 @@ getFileId + + + $distances + + + Ball + Stats::mean($values) + + + Stats::mean($values) + compute + compute + isContinuous + new self($center, $radius, $subsets) + spatialSplit + + + $longestDistance + + + $setId + + + $values + $values + + + $column + $leftCentroid + $leftCentroid + $rightCentroid + $sample + $sample + + + $retVal + $setId + $this->longestDistanceInNode + $values + + + null|int|string + null|int|string + + + $this->setId + $this->setId + $this->setId + + + propagateSetChanges + propagateSetChanges + resetFullyConnectedStatus + resetLongestEdge + + + + + $distances + + + Clique + Stats::mean($values) + + + Stats::mean($values) + compute + isContinuous + new self($dataset, $center, $radius) + + + $values + $values + + + $column + $sample + + + $labelToSetId[$label] + $labelToSetId[array_pop($labels)] + + + $labelToSetId[array_pop($labels)] + + + $label + $values + + + $labelToSetId[$this->dataset->label(0)] + + + + + EstimatorType::clusterer() + + + EstimatorType::clusterer() + Params::stringify($this->params()) + compatibility + + + $flatClusters + list<MstClusterer> + + + $dataset + + + + + center + center + center + center + cleanup + compute + compute + compute + compute + dataset + dataset + dataset + dataset + dataset + dataset + isPoint + radius + radius + radius + radius + radius + radius + radius + radius + radius + radius + radius + subsets + + + Labeled + + + $this->root + [$labels, $distances] + [$labels, $distances] + [array_keys($this->coreNeighborDistances[$sampleLabel]), array_values($this->coreNeighborDistances[$sampleLabel])] + + + $a + $b + $bestDistances + $k + $k + $label + $maxRange + $queryNode + $referenceNode + + + $bestDistances + $coreNeighborDistances + $coreNeighborDistances + $k + $label + $label + $largestOldCoreDistance + $longestDistance + $longestLeft + $longestRight + $maxRange + $oldDistances + $queryLeft->radius() + $queryNode + $queryNode->radius() + $queryNode->radius() + $queryNode->radius() + $queryRight->radius() + $querySample + $referenceNode + $referenceSample + $shortestDistance + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$sampleLabel] + $this->coreNeighborDistances[$sampleLabel] + reset($oldCoreNeighbors) + + + $a_vector + $b_vector + $sample + $sample + $sampleKey + $sampleKey + array_slice($neighborLabels, 0, $this->sampleSize) + + + $queryLabels[$queryKey] + $referenceLabels[$referenceKey] + + + $coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$neighborLabel][$label] + $this->coreNeighborDistances[$queryLabel][$referenceLabel] + $this->coreNeighborDistances[$queryLabel][$referenceLabel] + $this->coreNeighborDistances[$referenceLabel][$queryLabel] + $this->coreNeighborDistances[$referenceLabel][$queryLabel] + $this->nativeInterpointCache[$smallIndex][$largeIndex] + $this->nodeDistances[$smallIndex][$largeIndex] + + + $bestDistances[$label] + $bestDistances[$queryLabel] + $bestDistances[$queryLabel] + $coreNeighborDistances[$referenceLabel] + $oldCoreNeighbors[$label] + $oldCoreNeighbors[$neighborLabel] + $queryLabels[$queryKey] + $referenceLabels[$referenceKey] + $this->coreDistances[$label] + $this->coreDistances[$label] + $this->coreDistances[$label] + $this->coreDistances[$label] + $this->coreDistances[$neighborLabel] + $this->coreDistances[$queryLabel] + $this->coreDistances[$queryLabel] + $this->coreDistances[$queryLabel] + $this->coreDistances[$queryLabel] + $this->coreDistances[$referenceLabel] + $this->coreDistances[$referenceLabel] + $this->coreDistances[$referenceLabel] + $this->coreNeighborDistances[$a] + $this->coreNeighborDistances[$b] + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$label] + $this->coreNeighborDistances[$neighborLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$queryLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->coreNeighborDistances[$referenceLabel] + $this->nativeInterpointCache[$smallIndex] + $this->nativeInterpointCache[$smallIndex] + $this->nodeDistances[$smallIndex] + $this->nodeDistances[$smallIndex] + $updatedOldCoreLabels[$neighborLabel] + + + $this->nodeDistances[$smallIndex] + + + $bestDistance + $bestDistances[$queryLabel] + $child + $coreDistance + $coreNeighborDistances + $currentBound + $currentDistance + $label + $label + $label + $labels[] + $labels[] + $largeIndex + $largeIndex + $largeIndex + $largestOldCoreDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestLeft + $longestLeft + $longestRight + $longestRight + $neighborLabel + $nodeDistance + $oldDistances + $queryKey + $queryLabel + $queryLabels + $queryLeft + $queryLeft + $queryNodeId + $queryRight + $queryRight + $querySample + $querySamples + $radius + $referenceKey + $referenceLabel + $referenceLabels + $referenceLeft + $referenceNodeId + $referenceRight + $referenceSample + $referenceSamples + $shortestDistance + $smallIndex + $smallIndex + $smallIndex + $subChild + $subStack[] + + + float + float + float + float + + + children + getLongestDistance + getLongestDistance + getLongestDistance + labels + labels + left + left + left + radius + radius + radius + radius + radius + radius + radius + radius + radius + radius + right + right + right + samples + samples + setLongestDistance + + + $longestLeft + $longestRight + $queryNode->radius() + $queryNode->radius() + $queryNode->radius() + $shortestDistance + min($longestLeft, $longestRight) + + + $nodeDistance + $this->coreDistances[$label] + $this->coreDistances[$queryLabel] ?? INF + $this->coreNeighborDistances[$a][$b] + $this->coreNeighborDistances[$b][$a] + $this->nativeInterpointCache[$smallIndex][$largeIndex] + + + DualTreeBall + array{list<mixed>,list<float>} + array{list<mixed>,list<float>} + + + $this->dataset + $this->root + + + $sampleKey + $sampleKey + + + labels + labels + labels + + + children + dataset + dataset + dataset + dataset + + + resetFullyConnectedStatus + resetLongestEdge + + + $labels + + + + + $childCluster->getCoreEdges() + $childClusterEdges1 + $childClusterEdges1 + $childClusterEdges2 + $childClusterEdges2 + $childClusterVerticesToEdges1 + $childClusterVerticesToEdges2 + $currentLambda + $this->coreEdges + $this->mapVerticesToEdges[$currentVertex] + $this->mapVerticesToEdges[$currentVertex] + $this->remainingEdges + $this->remainingEdges + $this->remainingEdges + $this->remainingEdges + $this->remainingEdges + $vertexConnectedFrom + $vertexConnectedTo + + + $a["distance"] + $a["distance"] + $b["distance"] + $b["distance"] + $currentLongestEdge["distance"] + $currentLongestEdge["vertexFrom"] + $currentLongestEdge["vertexTo"] + $edge["finalLambda"] + $edge["finalLambda"] + $edge["finalLambda"] + $edge["finalLambda"] + $edge["vertexFrom"] + $edge["vertexFrom"] + $edge["vertexFrom"] + $edge["vertexFrom"] + $edge["vertexTo"] + $edge["vertexTo"] + $edge["vertexTo"] + $edge["vertexTo"] + $edge['vertexFrom'] + $edge['vertexTo'] + $this->mapVerticesToEdges[$vertexConnectedFrom] + $this->mapVerticesToEdges[$vertexConnectedTo] + + + $edge['finalLambda'] + $this->edges[$currentLongestEdgeKey]["finalLambda"] + $this->edges[$edgeKey]['finalLambda'] + $this->edges[$edgeKey]['finalLambda'] + $this->edges[$edgeToPrune]['finalLambda'] + + + $mapVerticesToEdges[$edge['vertexFrom']] + $mapVerticesToEdges[$edge['vertexTo']] + $this->mapVerticesToEdges[$vertexConnectedFrom] + $this->mapVerticesToEdges[$vertexConnectedTo] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexTo"]] + $vertices[$edge["vertexTo"]] + $vertices[$edge["vertexTo"]] + $vertices[$edge["vertexTo"]] + + + $this->remainingEdges[$edgeKey] + $this->remainingEdges[$edgeKey] + $this->remainingEdges[$edgeToPrune] + + + $childCluster + $childrenWeight + $currentLambda + $currentLongestEdge + $edge + $edge + $edge + $edge + $edgeLength + $lastLambda + $this->clusterWeight + $this->coreEdges + $this->mapVerticesToEdges + $this->mapVerticesToEdges + $this->remainingEdges + $this->remainingEdges + $vertexConnectedFrom + $vertexConnectedTo + $verticesInCluster[$currentVertex] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexFrom"]] + $vertices[$edge["vertexTo"]] + $vertices[$edge["vertexTo"]] + + + getClusterWeight + getCoreEdges + + + $childCluster->getClusterWeight() + $currentLambda - $lastLambda + $edgeLength + $lastLambda + $this->clusterWeight + + + $this->edges + + + $this->edges + $this->mapVerticesToEdges[$vertexConnectedFrom] + $this->mapVerticesToEdges[$vertexConnectedTo] + + + + + $child + $queryLeft + $queryLeft + $queryLeft + $queryRight + $queryRight + $queryRight + $referenceLeft + $referenceLeft + $referenceRight + $referenceRight + + + radius + radius + radius + radius + radius + radius + + + $edges + + + $queryNode + $referenceNode + + + $label + $longestLeft + $longestRight + $queryLabel + $queryLeft->radius() + $queryNode->radius() + $queryRight->radius() + $querySample + $referenceLabel + $referenceSample + $shortestDistance + + + $allLabels + + + $connectingEdge["vertexFrom"] + $connectingEdge["vertexTo"] + $newEdges[$queryNode->getSetId()]["distance"] + $queryLabels[$queryKey] + $referenceLabels[$referenceKey] + + + $newEdges[$querySetId] + $newEdges[$querySetId] + $newEdges[$querySetId] + $newEdges[$querySetId] + $queryLabels[$queryKey] + $referenceLabels[$referenceKey] + $vertexSets[$setId1] + $vertexSets[$setId1] + $vertexSets[$setId1] + $vertexSets[$setId2] + $vertexSets[$setId2] + $vertexToSetId[$connectingEdge["vertexFrom"]] + $vertexToSetId[$connectingEdge["vertexTo"]] + $vertexToSetId[$queryLabel] + $vertexToSetId[$referenceLabel] + $vertexToSetId[$vertexLabel] + + + $vertexSets[$setId1] + $vertexSets[$setId1] + $vertexSets[$setId2] + $vertexSets[$setId2] + + + $candidateDist + $connectingEdge + $currentBound + $edges[] + $label + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestDistance + $longestLeft + $longestLeft + $longestRight + $longestRight + $neighborLabel + $queryKey + $queryLabel + $queryLabels + $querySample + $querySamples + $querySetId + $referenceKey + $referenceLabel + $referenceLabels + $referenceSample + $referenceSamples + $referenceSetId + $setId1 + $setId1 + $setId2 + $setId2 + $shortestDistance + $vertexLabel + $vertexToSetId[$vertexLabel] + + + dataset + dataset + dataset + dataset + labels + labels + radius + radius + samples + samples + setLongestDistance + + + $queryNode->radius() + 2 * $queryNode->radius() + min($longestLeft, $longestRight) + + + list<array{vertexFrom:int|string,vertexTo:int|string,distance:float}> + + + $queryLeft + $queryLeft + $queryLeft + $queryRight + $queryRight + $queryRight + $referenceLeft + $referenceLeft + $referenceRight + $referenceRight + + + $newEdges + + + getLongestDistance + getLongestDistance + radius + radius + radius + radius + + + getLongestDistance + getLongestDistance + radius + radius + radius + radius + + + + + DataType::continuous() + + + DataType::continuous() + + + $a + $b + + $value @@ -272,7 +976,8 @@ - + + $this->findEntities($qb) $this->findEntities($qb) $this->findEntities($qb) $this->findEntities($qb) @@ -280,13 +985,17 @@ $this->findEntity($qb) $this->findEntity($qb) + + findClusterSample + $qb->executeQuery()->fetchAll(\PDO::FETCH_COLUMN) list<string> - + FaceDetection FaceDetection + \OCA\Recognize\Db\FaceDetection[] list<\OCA\Recognize\Db\FaceDetection> list<\OCA\Recognize\Db\FaceDetection> list<\OCA\Recognize\Db\FaceDetection> @@ -477,15 +1186,28 @@ - - $results - - + + compute compute - - $detection->getClusterId() !== null - + + $clusterId + $sampled + + + array_map(static fn ($key) => $detections[$key], $detectionKeys) + + + $unclusteredDetections[$detectionKey] + + + $clusterId + $sampled + + + getClusterId + getClusterId + diff --git a/tests/ClusterTest.php b/tests/ClusterTest.php index 9d4565fc..2ff11eb0 100644 --- a/tests/ClusterTest.php +++ b/tests/ClusterTest.php @@ -4,7 +4,9 @@ use OCA\Recognize\Db\FaceDetection; use OCA\Recognize\Db\FaceDetectionMapper; use OCA\Recognize\Service\FaceClusterAnalyzer; +use OCA\Recognize\Service\Logger; use Rubix\ML\Kernels\Distance\Euclidean; +use Symfony\Component\Console\Output\OutputInterface; use Test\TestCase; /** @@ -31,6 +33,12 @@ public function setUp(): void { $this->faceClusterAnalyzer = \OC::$server->get(FaceClusterAnalyzer::class); $this->faceClusterMapper = \OC::$server->get(FaceClusterMapper::class); + $logger = \OC::$server->get(Logger::class); + $cliOutput = $this->createMock(OutputInterface::class); + $cliOutput->method('writeln') + ->willReturnCallback(fn ($msg) => print($msg."\n")); + $logger->setCliOutput($cliOutput); + $clusters = $this->faceClusterMapper->findByUserId(self::TEST_USER1); foreach ($clusters as $cluster) { $this->faceClusterMapper->delete($cluster); @@ -193,15 +201,17 @@ public function testClusterMerging2() { $this->faceDetectionMapper->update($detection); } - $newDetection = new FaceDetection(); - $newDetection->setHeight(0.5); - $newDetection->setWidth(0.5); - $newDetection->setFileId(500000); - $nullVector = self::getNullVector(); - $nullVector[0] = 0.8; - $newDetection->setVector($nullVector); - $newDetection->setUserId(self::TEST_USER1); - $this->faceDetectionMapper->insert($newDetection); + for ($i = 0; $i < self::INITIAL_DETECTIONS_PER_CLUSTER; $i++) { + $newDetection = new FaceDetection(); + $newDetection->setHeight(0.5); + $newDetection->setWidth(0.5); + $newDetection->setFileId(500000 + $i); + $nullVector = self::getNullVector(); + $nullVector[0] = 1 + 0.001 * $i; + $newDetection->setVector($nullVector); + $newDetection->setUserId(self::TEST_USER1); + $this->faceDetectionMapper->insert($newDetection); + } $this->faceClusterAnalyzer->calculateClusters(self::TEST_USER1); @@ -210,7 +220,7 @@ public function testClusterMerging2() { self::assertCount(1, $clusters); $detections = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 2 + 1, $detections); + self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 3, $detections); } /** @@ -238,6 +248,7 @@ public function testClusterMergingWithThirdAdditionalCluster() { $detection->setClusterId($clusters[0]->getId()); $this->faceDetectionMapper->update($detection); } + $this->faceClusterMapper->delete($clusters[1]); $numOfDetections = self::INITIAL_DETECTIONS_PER_CLUSTER; $clusterValue = 3 * self::INITIAL_DETECTIONS_PER_CLUSTER; @@ -259,11 +270,11 @@ public function testClusterMergingWithThirdAdditionalCluster() { $clusters = $this->faceClusterMapper->findByUserId(self::TEST_USER1); self::assertCount(2, $clusters); - $detections = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 2, $detections); - - $detections = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER, $detections); + $detections1 = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); + $detections2 = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); + $counts = [count($detections1), count($detections2)]; + self::assertCount(1, array_filter($counts, fn ($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER), var_export($counts, true)); + self::assertCount(1, array_filter($counts, fn ($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER * 2), var_export($counts, true)); } /** @@ -323,16 +334,15 @@ public function testClusterTemptClusterMerging() { /** @var \OCA\Recognize\Db\FaceCluster[] $clusters */ $clusters = $this->faceClusterMapper->findByUserId(self::TEST_USER1); - self::assertCount(3, $clusters); + self::assertCount(4, $clusters); - $detections = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER * 2, $detections); + $detections1 = $this->faceDetectionMapper->findByClusterId($clusters[0]->getId()); + $detections2 = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); + $detections3 = $this->faceDetectionMapper->findByClusterId($clusters[2]->getId()); + $detections4 = $this->faceDetectionMapper->findByClusterId($clusters[3]->getId()); + $counts = [count($detections1), count($detections2), count($detections3), count($detections4)]; - $detections = $this->faceDetectionMapper->findByClusterId($clusters[1]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER, $detections); - - $detections = $this->faceDetectionMapper->findByClusterId($clusters[2]->getId()); - self::assertCount(self::INITIAL_DETECTIONS_PER_CLUSTER, $detections); + self::assertCount(4, array_filter($counts, fn ($count) => $count === self::INITIAL_DETECTIONS_PER_CLUSTER), var_export($counts, true)); } private static function getNullVector() {