diff --git a/appinfo/info.xml b/appinfo/info.xml index 6bb3d39258..90e007313b 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -34,7 +34,7 @@ The rating depends on the installed text processing backend. See [the rating ove Learn more about the Nextcloud Ethical AI Rating [in our blog](https://nextcloud.com/blog/nextcloud-ethical-ai-rating/). ]]> - 4.2.0-alpha.0 + 4.2.0-alpha.1 agpl Christoph Wurst GretaD @@ -90,6 +90,7 @@ Learn more about the Nextcloud Ethical AI Rating [in our blog](https://nextcloud OCA\Mail\Command\TrainAccount OCA\Mail\Command\UpdateAccount OCA\Mail\Command\UpdateSystemAutoresponders + OCA\Mail\Command\RunMetaEstimator OCA\Mail\Settings\AdminSettings diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index 01bea26944..216d489af4 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -45,7 +45,6 @@ use OCA\Mail\Listener\MessageCacheUpdaterListener; use OCA\Mail\Listener\MessageKnownSinceListener; use OCA\Mail\Listener\MoveJunkListener; -use OCA\Mail\Listener\NewMessageClassificationListener; use OCA\Mail\Listener\NewMessagesNotifier; use OCA\Mail\Listener\OauthTokenRefreshListener; use OCA\Mail\Listener\OptionalIndicesListener; @@ -130,7 +129,6 @@ public function register(IRegistrationContext $context): void { $context->registerEventListener(MessageDeletedEvent::class, MessageCacheUpdaterListener::class); $context->registerEventListener(MessageSentEvent::class, AddressCollectionListener::class); $context->registerEventListener(MessageSentEvent::class, InteractionListener::class); - $context->registerEventListener(NewMessagesSynchronized::class, NewMessageClassificationListener::class); $context->registerEventListener(NewMessagesSynchronized::class, MessageKnownSinceListener::class); $context->registerEventListener(NewMessagesSynchronized::class, NewMessagesNotifier::class); $context->registerEventListener(SynchronizationEvent::class, AccountSynchronizedThreadUpdaterListener::class); diff --git a/lib/BackgroundJob/TrainImportanceClassifierJob.php b/lib/BackgroundJob/TrainImportanceClassifierJob.php index d6482cbe62..0ec467a1a1 100644 --- a/lib/BackgroundJob/TrainImportanceClassifierJob.php +++ b/lib/BackgroundJob/TrainImportanceClassifierJob.php @@ -69,10 +69,7 @@ protected function run($argument) { } try { - $this->classifier->train( - $account, - $this->logger - ); + $this->classifier->train($account, $this->logger); } catch (Throwable $e) { $this->logger->error('Cron importance classifier training failed: ' . $e->getMessage(), [ 'exception' => $e, diff --git a/lib/Command/PredictImportance.php b/lib/Command/PredictImportance.php index 386dc413ea..fcaa458ec2 100644 --- a/lib/Command/PredictImportance.php +++ b/lib/Command/PredictImportance.php @@ -13,6 +13,7 @@ use OCA\Mail\Db\Message; use OCA\Mail\Service\AccountService; use OCA\Mail\Service\Classification\ImportanceClassifier; +use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; use OCP\IConfig; use Psr\Log\LoggerInterface; @@ -25,6 +26,7 @@ class PredictImportance extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; public const ARGUMENT_SENDER = 'sender'; + public const ARGUMENT_SUBJECT = 'subject'; private AccountService $accountService; private ImportanceClassifier $classifier; @@ -43,26 +45,27 @@ public function __construct(AccountService $service, $this->config = $config; } - /** - * @return void - */ - protected function configure() { + protected function configure(): void { $this->setName('mail:predict-importance'); $this->setDescription('Predict importance of an incoming message'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); $this->addArgument(self::ARGUMENT_SENDER, InputArgument::REQUIRED); + $this->addArgument(self::ARGUMENT_SUBJECT, InputArgument::OPTIONAL); } - public function isEnabled() { + public function isEnabled(): bool { return $this->config->getSystemValueBool('debug'); } - /** - * @return int - */ protected function execute(InputInterface $input, OutputInterface $output): int { $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); $sender = $input->getArgument(self::ARGUMENT_SENDER); + $subject = $input->getArgument(self::ARGUMENT_SUBJECT) ?? ''; + + $consoleLogger = new ConsoleLoggerDecorator( + $this->logger, + $output + ); try { $account = $this->accountService->findById($accountId); @@ -73,9 +76,11 @@ protected function execute(InputInterface $input, OutputInterface $output): int $fakeMessage = new Message(); $fakeMessage->setUid(0); $fakeMessage->setFrom(AddressList::parse("Name <$sender>")); + $fakeMessage->setSubject($subject); [$prediction] = $this->classifier->classifyImportance( $account, - [$fakeMessage] + [$fakeMessage], + $consoleLogger ); if ($prediction) { $output->writeln('Message is important'); diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php new file mode 100644 index 0000000000..b23a55ad4a --- /dev/null +++ b/lib/Command/RunMetaEstimator.php @@ -0,0 +1,117 @@ +accountService = $accountService; + $this->logger = $logger; + $this->classifier = $classifier; + $this->config = $config; + } + + protected function configure(): void { + $this->setName('mail:account:run-meta-estimator'); + $this->setDescription('Run the meta estimator for an account'); + $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); + $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); + } + + public function isEnabled(): bool { + return $this->config->getSystemValueBool('debug'); + } + + protected function execute(InputInterface $input, OutputInterface $output): int { + $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); + $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE); + + try { + $account = $this->accountService->findById($accountId); + } catch (DoesNotExistException $e) { + $output->writeln("Account $accountId does not exist"); + return 1; + } + + $consoleLogger = new ConsoleLoggerDecorator( + $this->logger, + $output + ); + + $estimator = static function () use ($consoleLogger) { + $params = [ + [5, 10, 15, 20, 25, 30, 35, 40], // Neighbors + [true, false], // Weighted? + [new Euclidean(), new Manhattan(), new Jaccard()], // Kernel + ]; + + $estimator = new GridSearch( + KNearestNeighbors::class, + $params, + new FBeta(), + new KFold(5) + ); + $estimator->setLogger($consoleLogger); + $estimator->setBackend(new Amp()); + return $estimator; + }; + + $pipeline = $this->classifier->train( + $account, + $consoleLogger, + $estimator, + $shuffle, + false, + ); + + /** @var GridSearch $metaEstimator */ + $metaEstimator = $pipeline?->getEstimator(); + if ($metaEstimator !== null) { + $output->writeln("Best estimator: {$metaEstimator->base()}"); + } + + $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); + $output->writeln('' . $mbs . 'MB of memory used'); + return 0; + } +} diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index a2d2b09cc4..19fef4ad0d 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -3,7 +3,7 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2019 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2019-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ @@ -23,6 +23,9 @@ class TrainAccount extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; + public const ARGUMENT_SHUFFLE = 'shuffle'; + public const ARGUMENT_DRY_RUN = 'dry-run'; + public const ARGUMENT_FORCE = 'force'; private AccountService $accountService; private ImportanceClassifier $classifier; @@ -41,20 +44,30 @@ public function __construct(AccountService $service, $this->classificationSettingsService = $classificationSettingsService; } - /** - * @return void - */ - protected function configure() { + protected function configure(): void { $this->setName('mail:account:train'); $this->setDescription('Train the classifier of new messages'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); + $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); + $this->addOption( + self::ARGUMENT_DRY_RUN, + null, + null, + 'Don\'t persist classifier after training' + ); + $this->addOption( + self::ARGUMENT_FORCE, + null, + null, + 'Train an estimator even if the classification is disabled by the user' + ); } - /** - * @return int - */ protected function execute(InputInterface $input, OutputInterface $output): int { $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); + $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE); + $dryRun = (bool)$input->getOption(self::ARGUMENT_DRY_RUN); + $force = (bool)$input->getOption(self::ARGUMENT_FORCE); try { $account = $this->accountService->findById($accountId); @@ -62,7 +75,8 @@ protected function execute(InputInterface $input, OutputInterface $output): int $output->writeln("account $accountId does not exist"); return 1; } - if (!$this->classificationSettingsService->isClassificationEnabled($account->getUserId())) { + + if (!$force && !$this->classificationSettingsService->isClassificationEnabled($account->getUserId())) { $output->writeln("classification is turned off for account $accountId"); return 2; } @@ -71,9 +85,13 @@ protected function execute(InputInterface $input, OutputInterface $output): int $this->logger, $output ); + $this->classifier->train( $account, - $consoleLogger + $consoleLogger, + null, + $shuffle, + !$dryRun ); $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); diff --git a/lib/Db/Classifier.php b/lib/Db/Classifier.php deleted file mode 100644 index dadbfaa4f6..0000000000 --- a/lib/Db/Classifier.php +++ /dev/null @@ -1,103 +0,0 @@ -addType('accountId', 'integer'); - $this->addType('type', 'string'); - $this->addType('appVersion', 'string'); - $this->addType('trainingSetSize', 'integer'); - $this->addType('validationSetSize', 'integer'); - $this->addType('recallImportant', 'float'); - $this->addType('precisionImportant', 'float'); - $this->addType('f1ScoreImportant', 'float'); - $this->addType('duration', 'integer'); - $this->addType('active', 'boolean'); - $this->addType('createdAt', 'integer'); - } -} diff --git a/lib/Db/ClassifierMapper.php b/lib/Db/ClassifierMapper.php deleted file mode 100644 index 946b70f8b0..0000000000 --- a/lib/Db/ClassifierMapper.php +++ /dev/null @@ -1,57 +0,0 @@ - - */ -class ClassifierMapper extends QBMapper { - public function __construct(IDBConnection $db) { - parent::__construct($db, 'mail_classifiers'); - } - - /** - * @param int $id - * - * @return Classifier - * @throws DoesNotExistException - */ - public function findLatest(int $id): Classifier { - $qb = $this->db->getQueryBuilder(); - - $select = $qb->select('*') - ->from($this->getTableName()) - ->where( - $qb->expr()->eq('account_id', $qb->createNamedParameter($id, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT), - $qb->expr()->eq('active', $qb->createNamedParameter(true, IQueryBuilder::PARAM_BOOL), IQueryBuilder::PARAM_BOOL) - ) - ->orderBy('created_at', 'desc') - ->setMaxResults(1); - - return $this->findEntity($select); - } - - public function findHistoric(int $threshold, int $limit) { - $qb = $this->db->getQueryBuilder(); - $select = $qb->select('*') - ->from($this->getTableName()) - ->where( - $qb->expr()->lte('created_at', $qb->createNamedParameter($threshold, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT), - ) - ->orderBy('created_at', 'asc') - ->setMaxResults($limit); - return $this->findEntities($select); - } -} diff --git a/lib/Listener/NewMessageClassificationListener.php b/lib/Listener/NewMessageClassificationListener.php deleted file mode 100644 index 01c0ca5f4f..0000000000 --- a/lib/Listener/NewMessageClassificationListener.php +++ /dev/null @@ -1,119 +0,0 @@ - - */ -class NewMessageClassificationListener implements IEventListener { - private const EXEMPT_FROM_CLASSIFICATION = [ - Horde_Imap_Client::SPECIALUSE_ARCHIVE, - Horde_Imap_Client::SPECIALUSE_DRAFTS, - Horde_Imap_Client::SPECIALUSE_JUNK, - Horde_Imap_Client::SPECIALUSE_SENT, - Horde_Imap_Client::SPECIALUSE_TRASH, - ]; - - /** @var ImportanceClassifier */ - private $classifier; - - /** @var TagMapper */ - private $tagMapper; - - /** @var LoggerInterface */ - private $logger; - - /** @var IMailManager */ - private $mailManager; - - private ClassificationSettingsService $classificationSettingsService; - - public function __construct(ImportanceClassifier $classifier, - TagMapper $tagMapper, - LoggerInterface $logger, - IMailManager $mailManager, - ClassificationSettingsService $classificationSettingsService) { - $this->classifier = $classifier; - $this->logger = $logger; - $this->tagMapper = $tagMapper; - $this->mailManager = $mailManager; - $this->classificationSettingsService = $classificationSettingsService; - } - - public function handle(Event $event): void { - if (!($event instanceof NewMessagesSynchronized)) { - return; - } - - if (!$this->classificationSettingsService->isClassificationEnabled($event->getAccount()->getUserId())) { - return; - } - - foreach (self::EXEMPT_FROM_CLASSIFICATION as $specialUse) { - if ($event->getMailbox()->isSpecialUse($specialUse)) { - // Nothing to do then - return; - } - } - - $messages = $event->getMessages(); - - // if this is a message that's been flagged / tagged as important before, we don't want to reclassify it again. - $doNotReclassify = $this->tagMapper->getTaggedMessageIdsForMessages( - $event->getMessages(), - $event->getAccount()->getUserId(), - Tag::LABEL_IMPORTANT - ); - $messages = array_filter($messages, static function ($message) use ($doNotReclassify) { - return ($message->getFlagImportant() === false || in_array($message->getMessageId(), $doNotReclassify, true)); - }); - - try { - $important = $this->tagMapper->getTagByImapLabel(Tag::LABEL_IMPORTANT, $event->getAccount()->getUserId()); - } catch (DoesNotExistException $e) { - // just in case - if we get here, the tag is missing - $this->logger->error('Could not find important tag for ' . $event->getAccount()->getUserId() . ' ' . $e->getMessage(), [ - 'exception' => $e, - ]); - return; - } - - try { - $predictions = $this->classifier->classifyImportance( - $event->getAccount(), - $messages - ); - - foreach ($event->getMessages() as $message) { - if ($predictions[$message->getUid()] ?? false) { - $this->mailManager->flagMessage($event->getAccount(), $event->getMailbox()->getName(), $message->getUid(), Tag::LABEL_IMPORTANT, true); - $this->mailManager->tagMessage($event->getAccount(), $event->getMailbox()->getName(), $message, $important, true); - } - } - } catch (ServiceException $e) { - $this->logger->error('Could not classify incoming message importance: ' . $e->getMessage(), [ - 'exception' => $e, - ]); - } - } -} diff --git a/lib/Migration/Version4100Date20241021091352.php b/lib/Migration/Version4100Date20241021091352.php new file mode 100644 index 0000000000..ccdf8be264 --- /dev/null +++ b/lib/Migration/Version4100Date20241021091352.php @@ -0,0 +1,29 @@ +dropTable('mail_classifiers'); + return $schema; + } +} diff --git a/lib/Model/Classifier.php b/lib/Model/Classifier.php new file mode 100644 index 0000000000..df4d21eeb1 --- /dev/null +++ b/lib/Model/Classifier.php @@ -0,0 +1,134 @@ +accountId; + } + + public function setAccountId(int $accountId): void { + $this->accountId = $accountId; + } + + public function getType(): string { + return $this->type; + } + + public function setType(string $type): void { + $this->type = $type; + } + + public function getEstimator(): string { + return $this->estimator; + } + + public function setEstimator(string $estimator): void { + $this->estimator = $estimator; + } + + public function getPersistenceVersion(): int { + return $this->persistenceVersion; + } + + public function setPersistenceVersion(int $persistenceVersion): void { + $this->persistenceVersion = $persistenceVersion; + } + + public function getTrainingSetSize(): int { + return $this->trainingSetSize; + } + + public function setTrainingSetSize(int $trainingSetSize): void { + $this->trainingSetSize = $trainingSetSize; + } + + public function getValidationSetSize(): int { + return $this->validationSetSize; + } + + public function setValidationSetSize(int $validationSetSize): void { + $this->validationSetSize = $validationSetSize; + } + + public function getRecallImportant(): float { + return $this->recallImportant; + } + + public function setRecallImportant(float $recallImportant): void { + $this->recallImportant = $recallImportant; + } + + public function getPrecisionImportant(): float { + return $this->precisionImportant; + } + + public function setPrecisionImportant(float $precisionImportant): void { + $this->precisionImportant = $precisionImportant; + } + + public function getF1ScoreImportant(): float { + return $this->f1ScoreImportant; + } + + public function setF1ScoreImportant(float $f1ScoreImportant): void { + $this->f1ScoreImportant = $f1ScoreImportant; + } + + public function getDuration(): int { + return $this->duration; + } + + public function setDuration(int $duration): void { + $this->duration = $duration; + } + + public function getCreatedAt(): int { + return $this->createdAt; + } + + public function setCreatedAt(int $createdAt): void { + $this->createdAt = $createdAt; + } + + #[ReturnTypeWillChange] + public function jsonSerialize() { + return [ + 'accountId' => $this->accountId, + 'type' => $this->type, + 'estimator' => $this->estimator, + 'persistenceVersion' => $this->persistenceVersion, + 'trainingSetSize' => $this->trainingSetSize, + 'validationSetSize' => $this->validationSetSize, + 'recallImportant' => $this->recallImportant, + 'precisionImportant' => $this->precisionImportant, + 'f1ScoreImportant' => $this->f1ScoreImportant, + 'duration' => $this->duration, + 'createdAt' => $this->createdAt, + ]; + } +} diff --git a/lib/Model/ClassifierPipeline.php b/lib/Model/ClassifierPipeline.php new file mode 100644 index 0000000000..f58d40ba89 --- /dev/null +++ b/lib/Model/ClassifierPipeline.php @@ -0,0 +1,29 @@ +estimator; + } + + public function getExtractor(): IExtractor { + return $this->extractor; + } +} diff --git a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php index eefadd92ba..54469bc220 100644 --- a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php @@ -3,7 +3,7 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ @@ -17,18 +17,25 @@ * Combines a set of DI'ed extractors so they can be used as one class */ class CompositeExtractor implements IExtractor { + private readonly SubjectExtractor $subjectExtractor; + /** @var IExtractor[] */ - private $extractors; + private readonly array $extractors; - public function __construct(ImportantMessagesExtractor $ex1, + public function __construct( + ImportantMessagesExtractor $ex1, ReadMessagesExtractor $ex2, RepliedMessagesExtractor $ex3, - SentMessagesExtractor $ex4) { + SentMessagesExtractor $ex4, + SubjectExtractor $ex5, + ) { + $this->subjectExtractor = $ex5; $this->extractors = [ $ex1, $ex2, $ex3, $ex4, + $ex5, ]; } @@ -46,4 +53,8 @@ public function extract(Message $message): array { return $extractor->extract($message); }, $this->extractors); } + + public function getSubjectExtractor(): SubjectExtractor { + return $this->subjectExtractor; + } } diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php new file mode 100644 index 0000000000..9118c054b7 --- /dev/null +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -0,0 +1,110 @@ +wordCountVectorizer = new WordCountVectorizer(self::MAX_VOCABULARY_SIZE); + $this->tfidf = new TfIdfTransformer(); + } + + public function getWordCountVectorizer(): WordCountVectorizer { + return $this->wordCountVectorizer; + } + + public function setWordCountVectorizer(WordCountVectorizer $wordCountVectorizer): void { + $this->wordCountVectorizer = $wordCountVectorizer; + $this->limitFeatureSize(); + } + + public function getTfIdf(): TfIdfTransformer { + return $this->tfidf; + } + + public function setTfidf(TfIdfTransformer $tfidf): void { + $this->tfidf = $tfidf; + } + + public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void { + /** @var array> $data */ + $data = array_map(static function (Message $message) { + return [ + 'text' => $message->getSubject() ?? '', + 'label' => $message->getFlagImportant() + ? ImportanceClassifier::LABEL_IMPORTANT + : ImportanceClassifier::LABEL_NOT_IMPORTANT, + ]; + }, $messages); + + // Fit transformers + Labeled::build(array_column($data, 'text'), array_column($data, 'label')) + ->apply(new MultibyteTextNormalizer()) + ->apply($this->wordCountVectorizer) + ->apply($this->tfidf); + + $this->limitFeatureSize(); + } + + public function extract(Message $message): array { + $sender = $message->getFrom()->first(); + if ($sender === null) { + throw new RuntimeException('This should not happen'); + } + + // Build training data set + $trainText = $message->getSubject() ?? ''; + + $trainDataSet = Unlabeled::build([[$trainText]]) + ->apply(new MultibyteTextNormalizer()) + ->apply($this->wordCountVectorizer) + ->apply($this->tfidf); + + // Use zeroed vector if no features could be extracted + if ($trainDataSet->numFeatures() === 0) { + $textFeatures = array_fill(0, $this->max, 0); + } else { + $textFeatures = $trainDataSet->sample(0); + } + + return $textFeatures; + } + + /** + * Limit feature vector length to actual size of vocabulary. + */ + private function limitFeatureSize(): void { + $vocabularies = $this->wordCountVectorizer->vocabularies(); + if (!isset($vocabularies[0])) { + // Should not happen but better safe than sorry + return; + } + + $this->max = count($vocabularies[0]); + } +} diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 434d0c5bcc..0368617308 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -3,30 +3,39 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Service\Classification; +use Closure; use Horde_Imap_Client; use OCA\Mail\Account; -use OCA\Mail\Db\Classifier; use OCA\Mail\Db\Mailbox; use OCA\Mail\Db\MailboxMapper; use OCA\Mail\Db\Message; use OCA\Mail\Db\MessageMapper; use OCA\Mail\Exception\ClassifierTrainingException; use OCA\Mail\Exception\ServiceException; +use OCA\Mail\Model\Classifier; +use OCA\Mail\Model\ClassifierPipeline; use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use OCA\Mail\Support\PerformanceLogger; +use OCA\Mail\Support\PerformanceLoggerTask; use OCP\AppFramework\Db\DoesNotExistException; +use Psr\Container\ContainerExceptionInterface; +use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; -use Rubix\ML\Classifiers\GaussianNB; +use Rubix\ML\Classifiers\KNearestNeighbors; use Rubix\ML\CrossValidation\Reports\MulticlassBreakdown; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; +use Rubix\ML\Kernels\Distance\Manhattan; +use Rubix\ML\Learner; +use Rubix\ML\Persistable; use RuntimeException; use function array_column; use function array_combine; @@ -64,12 +73,12 @@ class ImportanceClassifier { /** * @var string label for data sets that are classified as important */ - private const LABEL_IMPORTANT = 'i'; + public const LABEL_IMPORTANT = 'i'; /** * @var string label for data sets that are classified as not important */ - private const LABEL_NOT_IMPORTANT = 'ni'; + public const LABEL_NOT_IMPORTANT = 'ni'; /** * The minimum number of important messages. Without those the unsupervised @@ -81,7 +90,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private const MAX_TRAINING_SET_SIZE = 1000; + private const MAX_TRAINING_SET_SIZE = 300; /** @var MailboxMapper */ private $mailboxMapper; @@ -89,9 +98,6 @@ class ImportanceClassifier { /** @var MessageMapper */ private $messageMapper; - /** @var CompositeExtractor */ - private $extractor; - /** @var PersistenceService */ private $persistenceService; @@ -101,22 +107,40 @@ class ImportanceClassifier { /** @var ImportanceRulesClassifier */ private $rulesClassifier; - private LoggerInterface $logger; + private ContainerInterface $container; public function __construct(MailboxMapper $mailboxMapper, MessageMapper $messageMapper, - CompositeExtractor $extractor, PersistenceService $persistenceService, PerformanceLogger $performanceLogger, ImportanceRulesClassifier $rulesClassifier, - LoggerInterface $logger) { + ContainerInterface $container) { $this->mailboxMapper = $mailboxMapper; $this->messageMapper = $messageMapper; - $this->extractor = $extractor; $this->persistenceService = $persistenceService; $this->performanceLogger = $performanceLogger; $this->rulesClassifier = $rulesClassifier; - $this->logger = $logger; + $this->container = $container; + } + + private static function createDefaultEstimator(): KNearestNeighbors { + // A meta estimator was trained on the same data multiple times to average out the + // variance of the trained model. + // Parameters were chosen from the best configuration across 100 runs. + // Both variance (spread) and f1 score were considered. + // Note: Lower k values yield slightly higher f1 scores but show higher variances. + return new KNearestNeighbors(15, true, new Manhattan()); + } + + /** + * @throws ServiceException If the extractor is not available + */ + private function createExtractor(): CompositeExtractor { + try { + return $this->container->get(CompositeExtractor::class); + } catch (ContainerExceptionInterface $e) { + throw new ServiceException('Default extractor is not available', 0, $e); + } } private function filterMessageHasSenderEmail(Message $message): bool { @@ -124,22 +148,24 @@ private function filterMessageHasSenderEmail(Message $message): bool { } /** - * Train an account's classifier of important messages - * - * Train a classifier based on a user's existing messages to be able to derive - * importance markers for new incoming messages. - * - * To factor in (server-side) filtering into multiple mailboxes, the algorithm - * will not only look for messages in the inbox but also other non-special - * mailboxes. - * - * To prevent memory exhaustion, the process will only load a fixed maximum - * number of messages per account. + * Build a data set for training an importance classifier. * * @param Account $account + * @param IExtractor $extractor + * @param LoggerInterface $logger + * @param PerformanceLoggerTask|null $perf + * @param bool $shuffle + * @return array|null Returns null if there are not enough messages to train */ - public function train(Account $account, LoggerInterface $logger): void { - $perf = $this->performanceLogger->start('importance classifier training'); + public function buildDataSet( + Account $account, + IExtractor $extractor, + LoggerInterface $logger, + ?PerformanceLoggerTask $perf = null, + bool $shuffle = false, + ): ?array { + $perf ??= $this->performanceLogger->start('build data set for importance classifier training'); + $incomingMailboxes = $this->getIncomingMailboxes($account); $logger->debug('found ' . count($incomingMailboxes) . ' incoming mailbox(es)'); $perf->step('find incoming mailboxes'); @@ -160,30 +186,128 @@ public function train(Account $account, LoggerInterface $logger): void { $logger->debug('found ' . count($messages) . ' messages of which ' . count($importantMessages) . ' are important'); if (count($importantMessages) < self::COLD_START_THRESHOLD) { $logger->info('not enough messages to train a classifier'); - $perf->end(); - return; + return null; } $perf->step('find latest ' . self::MAX_TRAINING_SET_SIZE . ' messages'); - $dataSet = $this->getFeaturesAndImportance($account, $incomingMailboxes, $outgoingMailboxes, $messages); - $perf->step('extract features from messages'); + $dataSet = $this->getFeaturesAndImportance($account, $incomingMailboxes, $outgoingMailboxes, $messages, $extractor); + if ($shuffle) { + shuffle($dataSet); + } + + return $dataSet; + } + + /** + * Train an account's classifier of important messages + * + * Train a classifier based on a user's existing messages to be able to derive + * importance markers for new incoming messages. + * + * To factor in (server-side) filtering into multiple mailboxes, the algorithm + * will not only look for messages in the inbox but also other non-special + * mailboxes. + * + * To prevent memory exhaustion, the process will only load a fixed maximum + * number of messages per account. + * + * @param Account $account + * @param LoggerInterface $logger + * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. + * @param bool $shuffleDataSet Shuffle the data set before training + * @param bool $persist Persist the trained classifier to use it for message classification + * + * @return ClassifierPipeline|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained + * + * @throws ServiceException + */ + public function train( + Account $account, + LoggerInterface $logger, + ?Closure $estimator = null, + bool $shuffleDataSet = false, + bool $persist = true, + ): ?ClassifierPipeline { + $perf = $this->performanceLogger->start('importance classifier training'); + + $extractor = $this->createExtractor(); + $dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet); + if ($dataSet === null) { + return null; + } + + return $this->trainWithCustomDataSet( + $account, + $logger, + $dataSet, + $extractor, + $estimator, + $perf, + $persist, + ); + } + + /** + * Train a classifier using a custom data set. + * + * @param Account $account + * @param LoggerInterface $logger + * @param array $dataSet Training data set built by buildDataSet() + * @param CompositeExtractor $extractor Extractor used to extract the given data set + * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. + * @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task + * @param bool $persist Persist the trained classifier to use it for message classification + * + * @return ClassifierPipeline|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained + * + * @throws ServiceException + */ + private function trainWithCustomDataSet( + Account $account, + LoggerInterface $logger, + array $dataSet, + CompositeExtractor $extractor, + ?Closure $estimator, + ?PerformanceLoggerTask $perf = null, + bool $persist = true, + ): ?ClassifierPipeline { + $perf ??= $this->performanceLogger->start('importance classifier training'); + $estimator ??= self::createDefaultEstimator(...); /** * How many of the most recent messages are excluded from training? */ $validationThreshold = max( 5, - (int)(count($dataSet) * 0.1) + (int)(count($dataSet) * 0.2) ); $validationSet = array_slice($dataSet, 0, $validationThreshold); $trainingSet = array_slice($dataSet, $validationThreshold); - $logger->debug('data set split into ' . count($trainingSet) . ' training and ' . count($validationSet) . ' validation sets with ' . count($trainingSet[0]['features'] ?? []) . ' dimensions'); + + $validationSetImportantCount = 0; + $trainingSetImportantCount = 0; + foreach ($validationSet as $data) { + if ($data['label'] === self::LABEL_IMPORTANT) { + $validationSetImportantCount++; + } + } + foreach ($trainingSet as $data) { + if ($data['label'] === self::LABEL_IMPORTANT) { + $trainingSetImportantCount++; + } + } + + $logger->debug('data set split into ' . count($trainingSet) . ' (' . self::LABEL_IMPORTANT . ': ' . $trainingSetImportantCount . ') training and ' . count($validationSet) . ' (' . self::LABEL_IMPORTANT . ': ' . $validationSetImportantCount . ') validation sets with ' . count($trainingSet[0]['features'] ?? []) . ' dimensions'); + if ($validationSet === [] || $trainingSet === []) { $logger->info('not enough messages to train a classifier'); $perf->end(); - return; + return null; } - $validationEstimator = $this->trainClassifier($trainingSet); + + /** @var Learner&Estimator&Persistable $validationEstimator */ + $validationEstimator = $estimator(); + $this->trainClassifier($validationEstimator, $validationSet); try { $classifier = $this->validateClassifier( $validationEstimator, @@ -196,19 +320,31 @@ public function train(Account $account, LoggerInterface $logger): void { 'exception' => $e, ]); $perf->end(); - return; + return null; } $perf->step('train and validate classifier with training and validation sets'); - $estimator = $this->trainClassifier($dataSet); - $perf->step('train classifier with full data set'); + if (!$persist) { + return new ClassifierPipeline($validationEstimator, $extractor); + } - $classifier->setAccountId($account->getId()); + /** @var Learner&Estimator&Persistable $persistedEstimator */ + $persistedEstimator = $estimator(); + $this->trainClassifier($persistedEstimator, $dataSet); + $perf->step('train classifier with full data set'); $classifier->setDuration($perf->end()); - $this->persistenceService->persist($classifier, $estimator); - $logger->debug("classifier {$classifier->getId()} persisted"); + $classifier->setAccountId($account->getId()); + $classifier->setEstimator(get_class($persistedEstimator)); + $classifier->setPersistenceVersion(PersistenceService::VERSION); + + $this->persistenceService->persist($account, $persistedEstimator, $extractor); + $logger->debug("Classifier for account {$account->getId()} persisted", [ + 'classifier' => $classifier, + ]); + return new ClassifierPipeline($persistedEstimator, $extractor); } + /** * @param Account $account * @@ -259,17 +395,20 @@ private function getOutgoingMailboxes(Account $account): array { private function getFeaturesAndImportance(Account $account, array $incomingMailboxes, array $outgoingMailboxes, - array $messages): array { - $this->extractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages); + array $messages, + IExtractor $extractor): array { + $extractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages); - return array_map(function (Message $message) { + return array_map(static function (Message $message) use ($extractor) { $sender = $message->getFrom()->first(); if ($sender === null) { throw new RuntimeException('This should not happen'); } + $features = $extractor->extract($message); + return [ - 'features' => $this->extractor->extract($message), + 'features' => $features, 'label' => $message->getFlagImportant() ? self::LABEL_IMPORTANT : self::LABEL_NOT_IMPORTANT, 'sender' => $sender->getEmail(), ]; @@ -279,21 +418,38 @@ private function getFeaturesAndImportance(Account $account, /** * @param Account $account * @param Message[] $messages + * @param LoggerInterface $logger * * @return bool[] + * * @throws ServiceException */ - public function classifyImportance(Account $account, array $messages): array { - $estimator = null; + public function classifyImportance(Account $account, + array $messages, + LoggerInterface $logger): array { + $pipeline = null; try { - $estimator = $this->persistenceService->loadLatest($account); + $pipeline = $this->persistenceService->loadLatest($account); } catch (ServiceException $e) { - $this->logger->warning('Failed to load importance classifier: ' . $e->getMessage(), [ + $logger->warning('Failed to load persisted estimator and extractor: ' . $e->getMessage(), [ 'exception' => $e, ]); } - if ($estimator === null) { + // Persistence is disabled on some instances (due to no memory cache being available). + // Try to train a classifier on-the-fly on those instances. + if ($pipeline === null) { + $pipeline = $this->train($account, $logger); + } + + // Can't train pipeline and no persistence available? -> Skip rule based classifier ... + // It won't yield good results. Instead, we have to wait for the user to accumulate more + // emails so that training a classifier succeeds. + if ($pipeline === null && !$this->persistenceService->isAvailable()) { + return []; + } + + if ($pipeline === null) { $predictions = $this->rulesClassifier->classifyImportance( $account, $this->getIncomingMailboxes($account), @@ -309,15 +465,16 @@ public function classifyImportance(Account $account, array $messages): array { }, $messages) ); } - $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']); + $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']); $features = $this->getFeaturesAndImportance( $account, $this->getIncomingMailboxes($account), $this->getOutgoingMailboxes($account), - $messagesWithSender + $messagesWithSender, + $pipeline->getExtractor(), ); - $predictions = $estimator->predict( + $predictions = $pipeline->getEstimator()->predict( Unlabeled::build(array_column($features, 'features')) ); return array_combine( @@ -330,22 +487,20 @@ public function classifyImportance(Account $account, array $messages): array { ); } - private function trainClassifier(array $trainingSet): GaussianNB { - $classifier = new GaussianNB(); + private function trainClassifier(Learner $classifier, array $trainingSet): void { $classifier->train(Labeled::build( array_column($trainingSet, 'features'), array_column($trainingSet, 'label') )); - return $classifier; } /** * @param Estimator $estimator * @param array $trainingSet * @param array $validationSet + * @param LoggerInterface $logger * * @return Classifier - * @throws ClassifierTrainingException */ private function validateClassifier(Estimator $estimator, array $trainingSet, diff --git a/lib/Service/Classification/NewMessagesClassifier.php b/lib/Service/Classification/NewMessagesClassifier.php new file mode 100644 index 0000000000..64602f55b4 --- /dev/null +++ b/lib/Service/Classification/NewMessagesClassifier.php @@ -0,0 +1,109 @@ +insertBulk()). + * + * @param Message[] $messages + * @param Mailbox $mailbox + * @param Account $account + * @param Tag $importantTag + * @return void + */ + public function classifyNewMessages( + array $messages, + Mailbox $mailbox, + Account $account, + Tag $importantTag, + ): void { + $allowTagging = $this->preferences->getPreference($account->getUserId(), 'tag-classified-messages'); + if ($allowTagging === 'false') { + return; + } + + foreach (self::EXEMPT_FROM_CLASSIFICATION as $specialUse) { + if ($mailbox->isSpecialUse($specialUse)) { + // Nothing to do then + return; + } + } + + // if this is a message that's been flagged / tagged as important before, we don't want to reclassify it again. + $doNotReclassify = $this->tagMapper->getTaggedMessageIdsForMessages( + $messages, + $account->getUserId(), + Tag::LABEL_IMPORTANT + ); + $messages = array_filter($messages, static function ($message) use ($doNotReclassify) { + return ($message->getFlagImportant() === false || in_array($message->getMessageId(), $doNotReclassify, true)); + }); + + try { + $predictions = $this->classifier->classifyImportance( + $account, + $messages, + $this->logger + ); + + foreach ($messages as $message) { + $this->logger->info("Message {$message->getUid()} ({$message->getPreviewText()}) is " . ($predictions[$message->getUid()] ? 'important' : 'not important')); + if ($predictions[$message->getUid()] ?? false) { + $message->setFlagImportant(true); + $this->mailManager->flagMessage($account, $mailbox->getName(), $message->getUid(), Tag::LABEL_IMPORTANT, true); + $this->mailManager->tagMessage($account, $mailbox->getName(), $message, $importantTag, true); + } + } + } catch (ServiceException $e) { + $this->logger->error('Could not classify incoming message importance: ' . $e->getMessage(), [ + 'exception' => $e, + ]); + } catch (ClientException $e) { + $this->logger->error('Could not persist incoming message importance to IMAP: ' . $e->getMessage(), [ + 'exception' => $e, + ]); + } + } +} diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index 5cd73c09c6..ccfb40b0ff 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -3,258 +3,232 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Service\Classification; use OCA\Mail\Account; -use OCA\Mail\AppInfo\Application; -use OCA\Mail\Db\Classifier; -use OCA\Mail\Db\ClassifierMapper; -use OCA\Mail\Db\MailAccountMapper; use OCA\Mail\Exception\ServiceException; -use OCP\App\IAppManager; -use OCP\AppFramework\Db\DoesNotExistException; -use OCP\AppFramework\Utility\ITimeFactory; -use OCP\Files\IAppData; -use OCP\Files\NotFoundException; -use OCP\Files\NotPermittedException; +use OCA\Mail\Model\ClassifierPipeline; +use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; +use OCP\ICache; use OCP\ICacheFactory; -use OCP\ITempManager; -use Psr\Log\LoggerInterface; -use Rubix\ML\Estimator; +use Psr\Container\ContainerExceptionInterface; +use Psr\Container\ContainerInterface; use Rubix\ML\Learner; use Rubix\ML\Persistable; use Rubix\ML\PersistentModel; -use Rubix\ML\Persisters\Filesystem; +use Rubix\ML\Serializers\RBX; +use Rubix\ML\Transformers\TfIdfTransformer; +use Rubix\ML\Transformers\Transformer; +use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; -use function file_get_contents; -use function file_put_contents; use function get_class; -use function strlen; class PersistenceService { - private const ADD_DATA_FOLDER = 'classifiers'; + // Increment the version when changing the classifier or transformer pipeline + public const VERSION = 1; - /** @var ClassifierMapper */ - private $mapper; - - /** @var IAppData */ - private $appData; - - /** @var ITempManager */ - private $tempManager; - - /** @var ITimeFactory */ - private $timeFactory; - - /** @var IAppManager */ - private $appManager; - - /** @var ICacheFactory */ - private $cacheFactory; - - /** @var LoggerInterface */ - private $logger; - - /** @var MailAccountMapper */ - private $accountMapper; - - public function __construct(ClassifierMapper $mapper, - IAppData $appData, - ITempManager $tempManager, - ITimeFactory $timeFactory, - IAppManager $appManager, - ICacheFactory $cacheFactory, - LoggerInterface $logger, - MailAccountMapper $accountMapper) { - $this->mapper = $mapper; - $this->appData = $appData; - $this->tempManager = $tempManager; - $this->timeFactory = $timeFactory; - $this->appManager = $appManager; - $this->cacheFactory = $cacheFactory; - $this->logger = $logger; - $this->accountMapper = $accountMapper; + public function __construct( + private readonly ICacheFactory $cacheFactory, + private readonly ContainerInterface $container, + ) { } /** - * Persist the classifier data to the database and the estimator to storage + * Persist classifier, estimator and its transformers to the memory cache. * - * @param Classifier $classifier * @param Learner&Persistable $estimator * - * @throws ServiceException + * @throws ServiceException If any serialization fails */ - public function persist(Classifier $classifier, Learner $estimator): void { - /* - * First we have to insert the row to get the unique ID, but disable - * it until the model is persisted as well. Otherwise another process - * might try to load the model in the meantime and run into an error - * due to the missing data in app data. - */ - $classifier->setAppVersion($this->appManager->getAppVersion(Application::APP_ID)); - $classifier->setEstimator(get_class($estimator)); - $classifier->setActive(false); - $classifier->setCreatedAt($this->timeFactory->getTime()); - $this->mapper->insert($classifier); + public function persist( + Account $account, + Learner $estimator, + CompositeExtractor $extractor, + ): void { + $serializedData = []; /* - * Then we serialize the estimator into a temporary file + * First we serialize the estimator */ - $tmpPath = $this->tempManager->getTemporaryFile(); try { - $model = new PersistentModel($estimator, new Filesystem($tmpPath)); + $persister = new RubixMemoryPersister(); + $model = new PersistentModel($estimator, $persister); $model->save(); - $serializedClassifier = file_get_contents($tmpPath); - $this->logger->debug('Serialized classifier written to tmp file (' . strlen($serializedClassifier) . 'B'); + $serializedData[] = $persister->getData(); } catch (RuntimeException $e) { throw new ServiceException('Could not serialize classifier: ' . $e->getMessage(), 0, $e); } /* - * Then we store the serialized model to app data + * Then we serialize the transformer pipeline */ - try { + $transformers = [ + $extractor->getSubjectExtractor()->getWordCountVectorizer(), + $extractor->getSubjectExtractor()->getTfIdf(), + ]; + $serializer = new RBX(); + foreach ($transformers as $transformer) { try { - $folder = $this->appData->getFolder(self::ADD_DATA_FOLDER); - $this->logger->debug('Using existing folder for the serialized classifier'); - } catch (NotFoundException $e) { - $folder = $this->appData->newFolder(self::ADD_DATA_FOLDER); - $this->logger->debug('New folder created for serialized classifiers'); + $persister = new RubixMemoryPersister(); + /** + * This is how to serialize a transformer according to the official docs. + * PersistentModel can only be used on Learners which transformers don't implement. + * + * Ref https://docs.rubixml.com/2.0/model-persistence.html#persisting-transformers + * + * @psalm-suppress InternalMethod + */ + $serializer->serialize($transformer)->saveTo($persister); + $serializedData[] = $persister->getData(); + } catch (RuntimeException $e) { + throw new ServiceException('Could not serialize transformer: ' . $e->getMessage(), 0, $e); } - $file = $folder->newFile((string)$classifier->getId()); - $file->putContent($serializedClassifier); - $this->logger->debug('Serialized classifier written to app data'); - } catch (NotPermittedException $e) { - throw new ServiceException('Could not create classifiers directory: ' . $e->getMessage(), 0, $e); } - /* - * Now we set the model active so it can be used by the next request - */ - $classifier->setActive(true); - $this->mapper->update($classifier); + $this->setCached((string)$account->getId(), $serializedData); } /** - * @param Account $account + * Load the latest estimator and its transformers. * - * @return Estimator|null - * @throws ServiceException + * @throws ServiceException If any deserialization fails */ - public function loadLatest(Account $account): ?Estimator { - try { - $latestModel = $this->mapper->findLatest($account->getId()); - } catch (DoesNotExistException $e) { + public function loadLatest(Account $account): ?ClassifierPipeline { + $cached = $this->getCached((string)$account->getId()); + if ($cached === null) { return null; } - return $this->load($latestModel->getId()); - } - /** - * @param int $id - * - * @return Estimator - * @throws ServiceException - */ - public function load(int $id): Estimator { - $cached = $this->getCached($id); - if ($cached !== null) { - $this->logger->debug("Using cached serialized classifier $id"); - $serialized = $cached; - } else { - $this->logger->debug('Loading serialized classifier from app data'); + $serializedModel = $cached[0]; + $serializedTransformers = array_slice($cached, 1); + try { + $estimator = PersistentModel::load(new RubixMemoryPersister($serializedModel)); + } catch (RuntimeException $e) { + throw new ServiceException( + 'Could not deserialize persisted classifier: ' . $e->getMessage(), + 0, + $e, + ); + } + + $serializer = new RBX(); + $transformers = array_map(function (string $serializedTransformer) use ($serializer) { try { - $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER); - $modelFile = $modelsFolder->getFile((string)$id); - } catch (NotFoundException $e) { - $this->logger->debug("Could not load classifier $id: " . $e->getMessage()); - throw new ServiceException("Could not load classifier $id: " . $e->getMessage(), 0, $e); + $persister = new RubixMemoryPersister($serializedTransformer); + $transformer = $persister->load()->deserializeWith($serializer); + } catch (RuntimeException $e) { + throw new ServiceException( + 'Could not deserialize persisted transformer of classifier: ' . $e->getMessage(), + 0, + $e, + ); } - try { - $serialized = $modelFile->getContent(); - } catch (NotFoundException|NotPermittedException $e) { - $this->logger->debug("Could not load content for model file with classifier id $id: " . $e->getMessage()); - throw new ServiceException("Could not load content for model file with classifier id $id: " . $e->getMessage(), 0, $e); + if (!($transformer instanceof Transformer)) { + throw new ServiceException(sprintf( + 'Transformer is not an instance of %s: Got %s', + Transformer::class, + get_class($transformer), + )); } - $size = strlen($serialized); - $this->logger->debug("Serialized classifier loaded (size=$size)"); - $this->cache($id, $serialized); + return $transformer; + }, $serializedTransformers); + + $extractor = $this->loadExtractor($transformers); + + return new ClassifierPipeline($estimator, $extractor); + } + + /** + * Load and instantiate extractor based on the given transformers. + * + * @throws ServiceException If the transformers array contains unexpected instances or the composite extractor can't be instantiated + */ + private function loadExtractor(array $transformers): IExtractor { + $wordCountVectorizer = $transformers[0]; + if (!($wordCountVectorizer instanceof WordCountVectorizer)) { + throw new ServiceException(sprintf( + 'Failed to load persisted transformer: Expected %s, got %s', + WordCountVectorizer::class, + get_class($wordCountVectorizer), + )); } - $tmpPath = $this->tempManager->getTemporaryFile(); - file_put_contents($tmpPath, $serialized); + $tfidfTransformer = $transformers[1]; + if (!($tfidfTransformer instanceof TfIdfTransformer)) { + throw new ServiceException(sprintf( + 'Failed to load persisted transformer: Expected %s, got %s', + TfIdfTransformer::class, + get_class($tfidfTransformer), + )); + } try { - $estimator = PersistentModel::load(new Filesystem($tmpPath)); - } catch (RuntimeException $e) { - throw new ServiceException("Could not deserialize persisted classifier $id: " . $e->getMessage(), 0, $e); + /** @var CompositeExtractor $extractor */ + $extractor = $this->container->get(CompositeExtractor::class); + } catch (ContainerExceptionInterface $e) { + throw new ServiceException('Failed to instantiate the composite extractor', 0, $e); } - return $estimator; + $extractor->getSubjectExtractor()->setWordCountVectorizer($wordCountVectorizer); + $extractor->getSubjectExtractor()->setTfidf($tfidfTransformer); + return $extractor; } - public function cleanUp(): void { - $threshold = $this->timeFactory->getTime() - 30 * 24 * 60 * 60; - $totalAccounts = $this->accountMapper->getTotal(); - $classifiers = $this->mapper->findHistoric($threshold, $totalAccounts * 10); - foreach ($classifiers as $classifier) { - try { - $this->deleteModel($classifier->getId()); - $this->mapper->delete($classifier); - } catch (NotPermittedException $e) { - // Log and continue. This is not critical - $this->logger->warning('Could not clean-up old classifier', [ - 'id' => $classifier->getId(), - 'exception' => $e, - ]); - } + private function getCacheInstance(): ?ICache { + if (!$this->isAvailable()) { + return null; } + + $version = self::VERSION; + return $this->cacheFactory->createDistributed("mail-classifier/v$version/"); } /** - * @throws NotPermittedException + * @return string[]|null Array of serialized classifier and transformers */ - private function deleteModel(int $id): void { - $this->logger->debug('Deleting serialized classifier from app data', [ - 'id' => $id, - ]); - try { - $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER); - $modelFile = $modelsFolder->getFile((string)$id); - $modelFile->delete(); - } catch (NotFoundException $e) { - $this->logger->debug("Classifier model $id does not exist", [ - 'exception' => $e, - ]); + private function getCached(string $id): ?array { + $cache = $this->getCacheInstance(); + if ($cache === null) { + return null; } - } - - private function getCacheKey(int $id): string { - return "mail_classifier_$id"; - } - private function getCached(int $id): ?string { - if (!$this->cacheFactory->isLocalCacheAvailable()) { + $json = $cache->get($id); + if (!is_string($json)) { return null; } - $cache = $this->cacheFactory->createLocal(); - return $cache->get( - $this->getCacheKey($id) - ); + $serializedData = json_decode($json); + return array_map(base64_decode(...), $serializedData); } - private function cache(int $id, string $serialized): void { - if (!$this->cacheFactory->isLocalCacheAvailable()) { + /** + * @param string[] $serializedData Array of serialized classifier and transformers + */ + private function setCached(string $id, array $serializedData): void { + $cache = $this->getCacheInstance(); + if ($cache === null) { return; } - $cache = $this->cacheFactory->createLocal(); - $cache->set($this->getCacheKey($id), $serialized); + + // Serialized data contains binary, non-utf8 data so we encode it as base64 first + $encodedData = array_map(base64_encode(...), $serializedData); + $json = json_encode($encodedData, JSON_THROW_ON_ERROR); + + // Set a ttl of a week because a new model will be generated daily + $cache->set($id, $json, 3600 * 24 * 7); + } + + /** + * Returns true if the persistence layer is available on this Nextcloud server. + */ + public function isAvailable(): bool { + return $this->cacheFactory->isAvailable(); } } diff --git a/lib/Service/Classification/RubixMemoryPersister.php b/lib/Service/Classification/RubixMemoryPersister.php new file mode 100644 index 0000000000..c3abff2463 --- /dev/null +++ b/lib/Service/Classification/RubixMemoryPersister.php @@ -0,0 +1,36 @@ +data; + } + + public function save(Encoding $encoding): void { + $this->data = $encoding->data(); + } + + public function load(): Encoding { + return new Encoding($this->data); + } + + public function __toString() { + return self::class; + } +} diff --git a/lib/Service/CleanupService.php b/lib/Service/CleanupService.php index 65fa421200..74f5c0070a 100644 --- a/lib/Service/CleanupService.php +++ b/lib/Service/CleanupService.php @@ -17,7 +17,6 @@ use OCA\Mail\Db\MessageRetentionMapper; use OCA\Mail\Db\MessageSnoozeMapper; use OCA\Mail\Db\TagMapper; -use OCA\Mail\Service\Classification\PersistenceService; use OCA\Mail\Support\PerformanceLogger; use OCP\AppFramework\Utility\ITimeFactory; use Psr\Log\LoggerInterface; @@ -44,7 +43,6 @@ class CleanupService { private MessageSnoozeMapper $messageSnoozeMapper; - private PersistenceService $classifierPersistenceService; private ITimeFactory $timeFactory; public function __construct(MailAccountMapper $mailAccountMapper, @@ -55,7 +53,6 @@ public function __construct(MailAccountMapper $mailAccountMapper, TagMapper $tagMapper, MessageRetentionMapper $messageRetentionMapper, MessageSnoozeMapper $messageSnoozeMapper, - PersistenceService $classifierPersistenceService, ITimeFactory $timeFactory) { $this->aliasMapper = $aliasMapper; $this->mailboxMapper = $mailboxMapper; @@ -64,7 +61,6 @@ public function __construct(MailAccountMapper $mailAccountMapper, $this->tagMapper = $tagMapper; $this->messageRetentionMapper = $messageRetentionMapper; $this->messageSnoozeMapper = $messageSnoozeMapper; - $this->classifierPersistenceService = $classifierPersistenceService; $this->mailAccountMapper = $mailAccountMapper; $this->timeFactory = $timeFactory; } @@ -92,8 +88,6 @@ public function cleanUp(LoggerInterface $logger): void { $task->step('delete expired messages'); $this->messageSnoozeMapper->deleteOrphans(); $task->step('delete orphan snoozes'); - $this->classifierPersistenceService->cleanUp(); - $task->step('delete orphan classifiers'); $task->end(); } } diff --git a/lib/Service/Sync/ImapToDbSynchronizer.php b/lib/Service/Sync/ImapToDbSynchronizer.php index 9c52aad632..dd4af7add2 100644 --- a/lib/Service/Sync/ImapToDbSynchronizer.php +++ b/lib/Service/Sync/ImapToDbSynchronizer.php @@ -18,6 +18,8 @@ use OCA\Mail\Db\Mailbox; use OCA\Mail\Db\MailboxMapper; use OCA\Mail\Db\MessageMapper as DatabaseMessageMapper; +use OCA\Mail\Db\Tag; +use OCA\Mail\Db\TagMapper; use OCA\Mail\Events\NewMessagesSynchronized; use OCA\Mail\Events\SynchronizationEvent; use OCA\Mail\Exception\ClientException; @@ -32,7 +34,9 @@ use OCA\Mail\IMAP\Sync\Request; use OCA\Mail\IMAP\Sync\Synchronizer; use OCA\Mail\Model\IMAPMessage; +use OCA\Mail\Service\Classification\NewMessagesClassifier; use OCA\Mail\Support\PerformanceLogger; +use OCP\AppFramework\Db\DoesNotExistException; use OCP\EventDispatcher\IEventDispatcher; use Psr\Log\LoggerInterface; use Throwable; @@ -72,15 +76,21 @@ class ImapToDbSynchronizer { /** @var IMailManager */ private $mailManager; + private TagMapper $tagMapper; + private NewMessagesClassifier $newMessagesClassifier; + public function __construct(DatabaseMessageMapper $dbMapper, IMAPClientFactory $clientFactory, ImapMessageMapper $imapMapper, MailboxMapper $mailboxMapper, + DatabaseMessageMapper $messageMapper, Synchronizer $synchronizer, IEventDispatcher $dispatcher, PerformanceLogger $performanceLogger, LoggerInterface $logger, - IMailManager $mailManager) { + IMailManager $mailManager, + TagMapper $tagMapper, + NewMessagesClassifier $newMessagesClassifier) { $this->dbMapper = $dbMapper; $this->clientFactory = $clientFactory; $this->imapMapper = $imapMapper; @@ -90,6 +100,8 @@ public function __construct(DatabaseMessageMapper $dbMapper, $this->performanceLogger = $performanceLogger; $this->logger = $logger; $this->mailManager = $mailManager; + $this->tagMapper = $tagMapper; + $this->newMessagesClassifier = $newMessagesClassifier; } /** @@ -105,9 +117,9 @@ public function syncAccount(Account $account, $snoozeMailboxId = $account->getMailAccount()->getSnoozeMailboxId(); $sentMailboxId = $account->getMailAccount()->getSentMailboxId(); $trashRetentionDays = $account->getMailAccount()->getTrashRetentionDays(); - + $client = $this->clientFactory->getClient($account); - + foreach ($this->mailboxMapper->findAll($account) as $mailbox) { $syncTrash = $trashMailboxId === $mailbox->getId() && $trashRetentionDays !== null; $syncSnooze = $snoozeMailboxId === $mailbox->getId(); @@ -131,7 +143,7 @@ public function syncAccount(Account $account, $rebuildThreads = true; } } - + $client->logout(); $this->dispatcher->dispatchTyped( @@ -413,6 +425,15 @@ private function runPartialSync( }); } + $importantTag = null; + try { + $importantTag = $this->tagMapper->getTagByImapLabel(Tag::LABEL_IMPORTANT, $account->getUserId()); + } catch (DoesNotExistException $e) { + $this->logger->error('Could not find important tag for ' . $account->getUserId() . ' ' . $e->getMessage(), [ + 'exception' => $e, + ]); + } + foreach (array_chunk($newMessages, 500) as $chunk) { $dbMessages = array_map(static function (IMAPMessage $imapMessage) use ($mailbox, $account) { return $imapMessage->toDbMessage($mailbox->getId(), $account->getMailAccount()); @@ -420,6 +441,15 @@ private function runPartialSync( $this->dbMapper->insertBulk($account, ...$dbMessages); + if ($importantTag) { + $this->newMessagesClassifier->classifyNewMessages( + $dbMessages, + $mailbox, + $account, + $importantTag, + ); + } + $this->dispatcher->dispatch( NewMessagesSynchronized::class, new NewMessagesSynchronized($account, $mailbox, $dbMessages)