diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index 44bbca9e6f..9118c054b7 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -11,6 +11,7 @@ use OCA\Mail\Account; use OCA\Mail\Db\Message; +use OCA\Mail\Service\Classification\ImportanceClassifier; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Transformers\MultibyteTextNormalizer; @@ -21,13 +22,15 @@ use function array_map; class SubjectExtractor implements IExtractor { + private const MAX_VOCABULARY_SIZE = 500; + private WordCountVectorizer $wordCountVectorizer; private TfIdfTransformer $tfidf; private int $max = -1; public function __construct() { // Limit vocabulary to limit memory usage - $this->wordCountVectorizer = new WordCountVectorizer(500); + $this->wordCountVectorizer = new WordCountVectorizer(self::MAX_VOCABULARY_SIZE); $this->tfidf = new TfIdfTransformer(); } @@ -53,15 +56,14 @@ public function prepare(Account $account, array $incomingMailboxes, array $outgo $data = array_map(static function (Message $message) { return [ 'text' => $message->getSubject() ?? '', - 'label' => $message->getFlagImportant() ? 'i' : 'ni', + 'label' => $message->getFlagImportant() + ? ImportanceClassifier::LABEL_IMPORTANT + : ImportanceClassifier::LABEL_NOT_IMPORTANT, ]; }, $messages); // Fit transformers - Labeled::build( - array_column($data, 'text'), - array_column($data, 'label'), - ) + Labeled::build(array_column($data, 'text'), array_column($data, 'label')) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) ->apply($this->tfidf); @@ -104,6 +106,5 @@ private function limitFeatureSize(): void { } $this->max = count($vocabularies[0]); - echo("WCF vocab size: {$this->max}\n"); } } diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index fea58e4556..7211d826ba 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -72,12 +72,12 @@ class ImportanceClassifier { /** * @var string label for data sets that are classified as important */ - private const LABEL_IMPORTANT = 'i'; + public const LABEL_IMPORTANT = 'i'; /** * @var string label for data sets that are classified as not important */ - private const LABEL_NOT_IMPORTANT = 'ni'; + public const LABEL_NOT_IMPORTANT = 'ni'; /** * The minimum number of important messages. Without those the unsupervised