diff --git a/appinfo/info.xml b/appinfo/info.xml
index 6bb3d39258..90e007313b 100644
--- a/appinfo/info.xml
+++ b/appinfo/info.xml
@@ -34,7 +34,7 @@ The rating depends on the installed text processing backend. See [the rating ove
Learn more about the Nextcloud Ethical AI Rating [in our blog](https://nextcloud.com/blog/nextcloud-ethical-ai-rating/).
]]>
- 4.2.0-alpha.0
+ 4.2.0-alpha.1
agpl
Christoph Wurst
GretaD
@@ -90,6 +90,7 @@ Learn more about the Nextcloud Ethical AI Rating [in our blog](https://nextcloud
OCA\Mail\Command\TrainAccount
OCA\Mail\Command\UpdateAccount
OCA\Mail\Command\UpdateSystemAutoresponders
+ OCA\Mail\Command\RunMetaEstimator
OCA\Mail\Settings\AdminSettings
diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php
index 01bea26944..216d489af4 100644
--- a/lib/AppInfo/Application.php
+++ b/lib/AppInfo/Application.php
@@ -45,7 +45,6 @@
use OCA\Mail\Listener\MessageCacheUpdaterListener;
use OCA\Mail\Listener\MessageKnownSinceListener;
use OCA\Mail\Listener\MoveJunkListener;
-use OCA\Mail\Listener\NewMessageClassificationListener;
use OCA\Mail\Listener\NewMessagesNotifier;
use OCA\Mail\Listener\OauthTokenRefreshListener;
use OCA\Mail\Listener\OptionalIndicesListener;
@@ -130,7 +129,6 @@ public function register(IRegistrationContext $context): void {
$context->registerEventListener(MessageDeletedEvent::class, MessageCacheUpdaterListener::class);
$context->registerEventListener(MessageSentEvent::class, AddressCollectionListener::class);
$context->registerEventListener(MessageSentEvent::class, InteractionListener::class);
- $context->registerEventListener(NewMessagesSynchronized::class, NewMessageClassificationListener::class);
$context->registerEventListener(NewMessagesSynchronized::class, MessageKnownSinceListener::class);
$context->registerEventListener(NewMessagesSynchronized::class, NewMessagesNotifier::class);
$context->registerEventListener(SynchronizationEvent::class, AccountSynchronizedThreadUpdaterListener::class);
diff --git a/lib/BackgroundJob/TrainImportanceClassifierJob.php b/lib/BackgroundJob/TrainImportanceClassifierJob.php
index d6482cbe62..0ec467a1a1 100644
--- a/lib/BackgroundJob/TrainImportanceClassifierJob.php
+++ b/lib/BackgroundJob/TrainImportanceClassifierJob.php
@@ -69,10 +69,7 @@ protected function run($argument) {
}
try {
- $this->classifier->train(
- $account,
- $this->logger
- );
+ $this->classifier->train($account, $this->logger);
} catch (Throwable $e) {
$this->logger->error('Cron importance classifier training failed: ' . $e->getMessage(), [
'exception' => $e,
diff --git a/lib/Command/PredictImportance.php b/lib/Command/PredictImportance.php
index 386dc413ea..fcaa458ec2 100644
--- a/lib/Command/PredictImportance.php
+++ b/lib/Command/PredictImportance.php
@@ -13,6 +13,7 @@
use OCA\Mail\Db\Message;
use OCA\Mail\Service\AccountService;
use OCA\Mail\Service\Classification\ImportanceClassifier;
+use OCA\Mail\Support\ConsoleLoggerDecorator;
use OCP\AppFramework\Db\DoesNotExistException;
use OCP\IConfig;
use Psr\Log\LoggerInterface;
@@ -25,6 +26,7 @@
class PredictImportance extends Command {
public const ARGUMENT_ACCOUNT_ID = 'account-id';
public const ARGUMENT_SENDER = 'sender';
+ public const ARGUMENT_SUBJECT = 'subject';
private AccountService $accountService;
private ImportanceClassifier $classifier;
@@ -43,26 +45,27 @@ public function __construct(AccountService $service,
$this->config = $config;
}
- /**
- * @return void
- */
- protected function configure() {
+ protected function configure(): void {
$this->setName('mail:predict-importance');
$this->setDescription('Predict importance of an incoming message');
$this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED);
$this->addArgument(self::ARGUMENT_SENDER, InputArgument::REQUIRED);
+ $this->addArgument(self::ARGUMENT_SUBJECT, InputArgument::OPTIONAL);
}
- public function isEnabled() {
+ public function isEnabled(): bool {
return $this->config->getSystemValueBool('debug');
}
- /**
- * @return int
- */
protected function execute(InputInterface $input, OutputInterface $output): int {
$accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID);
$sender = $input->getArgument(self::ARGUMENT_SENDER);
+ $subject = $input->getArgument(self::ARGUMENT_SUBJECT) ?? '';
+
+ $consoleLogger = new ConsoleLoggerDecorator(
+ $this->logger,
+ $output
+ );
try {
$account = $this->accountService->findById($accountId);
@@ -73,9 +76,11 @@ protected function execute(InputInterface $input, OutputInterface $output): int
$fakeMessage = new Message();
$fakeMessage->setUid(0);
$fakeMessage->setFrom(AddressList::parse("Name <$sender>"));
+ $fakeMessage->setSubject($subject);
[$prediction] = $this->classifier->classifyImportance(
$account,
- [$fakeMessage]
+ [$fakeMessage],
+ $consoleLogger
);
if ($prediction) {
$output->writeln('Message is important');
diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php
new file mode 100644
index 0000000000..b23a55ad4a
--- /dev/null
+++ b/lib/Command/RunMetaEstimator.php
@@ -0,0 +1,117 @@
+accountService = $accountService;
+ $this->logger = $logger;
+ $this->classifier = $classifier;
+ $this->config = $config;
+ }
+
+ protected function configure(): void {
+ $this->setName('mail:account:run-meta-estimator');
+ $this->setDescription('Run the meta estimator for an account');
+ $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED);
+ $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training');
+ }
+
+ public function isEnabled(): bool {
+ return $this->config->getSystemValueBool('debug');
+ }
+
+ protected function execute(InputInterface $input, OutputInterface $output): int {
+ $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID);
+ $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE);
+
+ try {
+ $account = $this->accountService->findById($accountId);
+ } catch (DoesNotExistException $e) {
+ $output->writeln("Account $accountId does not exist");
+ return 1;
+ }
+
+ $consoleLogger = new ConsoleLoggerDecorator(
+ $this->logger,
+ $output
+ );
+
+ $estimator = static function () use ($consoleLogger) {
+ $params = [
+ [5, 10, 15, 20, 25, 30, 35, 40], // Neighbors
+ [true, false], // Weighted?
+ [new Euclidean(), new Manhattan(), new Jaccard()], // Kernel
+ ];
+
+ $estimator = new GridSearch(
+ KNearestNeighbors::class,
+ $params,
+ new FBeta(),
+ new KFold(5)
+ );
+ $estimator->setLogger($consoleLogger);
+ $estimator->setBackend(new Amp());
+ return $estimator;
+ };
+
+ $pipeline = $this->classifier->train(
+ $account,
+ $consoleLogger,
+ $estimator,
+ $shuffle,
+ false,
+ );
+
+ /** @var GridSearch $metaEstimator */
+ $metaEstimator = $pipeline?->getEstimator();
+ if ($metaEstimator !== null) {
+ $output->writeln("Best estimator: {$metaEstimator->base()}");
+ }
+
+ $mbs = (int)(memory_get_peak_usage() / 1024 / 1024);
+ $output->writeln('' . $mbs . 'MB of memory used');
+ return 0;
+ }
+}
diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php
index a2d2b09cc4..19fef4ad0d 100644
--- a/lib/Command/TrainAccount.php
+++ b/lib/Command/TrainAccount.php
@@ -3,7 +3,7 @@
declare(strict_types=1);
/**
- * SPDX-FileCopyrightText: 2019 Nextcloud GmbH and Nextcloud contributors
+ * SPDX-FileCopyrightText: 2019-2024 Nextcloud GmbH and Nextcloud contributors
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
@@ -23,6 +23,9 @@
class TrainAccount extends Command {
public const ARGUMENT_ACCOUNT_ID = 'account-id';
+ public const ARGUMENT_SHUFFLE = 'shuffle';
+ public const ARGUMENT_DRY_RUN = 'dry-run';
+ public const ARGUMENT_FORCE = 'force';
private AccountService $accountService;
private ImportanceClassifier $classifier;
@@ -41,20 +44,30 @@ public function __construct(AccountService $service,
$this->classificationSettingsService = $classificationSettingsService;
}
- /**
- * @return void
- */
- protected function configure() {
+ protected function configure(): void {
$this->setName('mail:account:train');
$this->setDescription('Train the classifier of new messages');
$this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED);
+ $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training');
+ $this->addOption(
+ self::ARGUMENT_DRY_RUN,
+ null,
+ null,
+ 'Don\'t persist classifier after training'
+ );
+ $this->addOption(
+ self::ARGUMENT_FORCE,
+ null,
+ null,
+ 'Train an estimator even if the classification is disabled by the user'
+ );
}
- /**
- * @return int
- */
protected function execute(InputInterface $input, OutputInterface $output): int {
$accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID);
+ $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE);
+ $dryRun = (bool)$input->getOption(self::ARGUMENT_DRY_RUN);
+ $force = (bool)$input->getOption(self::ARGUMENT_FORCE);
try {
$account = $this->accountService->findById($accountId);
@@ -62,7 +75,8 @@ protected function execute(InputInterface $input, OutputInterface $output): int
$output->writeln("account $accountId does not exist");
return 1;
}
- if (!$this->classificationSettingsService->isClassificationEnabled($account->getUserId())) {
+
+ if (!$force && !$this->classificationSettingsService->isClassificationEnabled($account->getUserId())) {
$output->writeln("classification is turned off for account $accountId");
return 2;
}
@@ -71,9 +85,13 @@ protected function execute(InputInterface $input, OutputInterface $output): int
$this->logger,
$output
);
+
$this->classifier->train(
$account,
- $consoleLogger
+ $consoleLogger,
+ null,
+ $shuffle,
+ !$dryRun
);
$mbs = (int)(memory_get_peak_usage() / 1024 / 1024);
diff --git a/lib/Db/Classifier.php b/lib/Db/Classifier.php
deleted file mode 100644
index dadbfaa4f6..0000000000
--- a/lib/Db/Classifier.php
+++ /dev/null
@@ -1,103 +0,0 @@
-addType('accountId', 'integer');
- $this->addType('type', 'string');
- $this->addType('appVersion', 'string');
- $this->addType('trainingSetSize', 'integer');
- $this->addType('validationSetSize', 'integer');
- $this->addType('recallImportant', 'float');
- $this->addType('precisionImportant', 'float');
- $this->addType('f1ScoreImportant', 'float');
- $this->addType('duration', 'integer');
- $this->addType('active', 'boolean');
- $this->addType('createdAt', 'integer');
- }
-}
diff --git a/lib/Db/ClassifierMapper.php b/lib/Db/ClassifierMapper.php
deleted file mode 100644
index 946b70f8b0..0000000000
--- a/lib/Db/ClassifierMapper.php
+++ /dev/null
@@ -1,57 +0,0 @@
-
- */
-class ClassifierMapper extends QBMapper {
- public function __construct(IDBConnection $db) {
- parent::__construct($db, 'mail_classifiers');
- }
-
- /**
- * @param int $id
- *
- * @return Classifier
- * @throws DoesNotExistException
- */
- public function findLatest(int $id): Classifier {
- $qb = $this->db->getQueryBuilder();
-
- $select = $qb->select('*')
- ->from($this->getTableName())
- ->where(
- $qb->expr()->eq('account_id', $qb->createNamedParameter($id, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT),
- $qb->expr()->eq('active', $qb->createNamedParameter(true, IQueryBuilder::PARAM_BOOL), IQueryBuilder::PARAM_BOOL)
- )
- ->orderBy('created_at', 'desc')
- ->setMaxResults(1);
-
- return $this->findEntity($select);
- }
-
- public function findHistoric(int $threshold, int $limit) {
- $qb = $this->db->getQueryBuilder();
- $select = $qb->select('*')
- ->from($this->getTableName())
- ->where(
- $qb->expr()->lte('created_at', $qb->createNamedParameter($threshold, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT),
- )
- ->orderBy('created_at', 'asc')
- ->setMaxResults($limit);
- return $this->findEntities($select);
- }
-}
diff --git a/lib/Listener/NewMessageClassificationListener.php b/lib/Listener/NewMessageClassificationListener.php
deleted file mode 100644
index 01c0ca5f4f..0000000000
--- a/lib/Listener/NewMessageClassificationListener.php
+++ /dev/null
@@ -1,119 +0,0 @@
-
- */
-class NewMessageClassificationListener implements IEventListener {
- private const EXEMPT_FROM_CLASSIFICATION = [
- Horde_Imap_Client::SPECIALUSE_ARCHIVE,
- Horde_Imap_Client::SPECIALUSE_DRAFTS,
- Horde_Imap_Client::SPECIALUSE_JUNK,
- Horde_Imap_Client::SPECIALUSE_SENT,
- Horde_Imap_Client::SPECIALUSE_TRASH,
- ];
-
- /** @var ImportanceClassifier */
- private $classifier;
-
- /** @var TagMapper */
- private $tagMapper;
-
- /** @var LoggerInterface */
- private $logger;
-
- /** @var IMailManager */
- private $mailManager;
-
- private ClassificationSettingsService $classificationSettingsService;
-
- public function __construct(ImportanceClassifier $classifier,
- TagMapper $tagMapper,
- LoggerInterface $logger,
- IMailManager $mailManager,
- ClassificationSettingsService $classificationSettingsService) {
- $this->classifier = $classifier;
- $this->logger = $logger;
- $this->tagMapper = $tagMapper;
- $this->mailManager = $mailManager;
- $this->classificationSettingsService = $classificationSettingsService;
- }
-
- public function handle(Event $event): void {
- if (!($event instanceof NewMessagesSynchronized)) {
- return;
- }
-
- if (!$this->classificationSettingsService->isClassificationEnabled($event->getAccount()->getUserId())) {
- return;
- }
-
- foreach (self::EXEMPT_FROM_CLASSIFICATION as $specialUse) {
- if ($event->getMailbox()->isSpecialUse($specialUse)) {
- // Nothing to do then
- return;
- }
- }
-
- $messages = $event->getMessages();
-
- // if this is a message that's been flagged / tagged as important before, we don't want to reclassify it again.
- $doNotReclassify = $this->tagMapper->getTaggedMessageIdsForMessages(
- $event->getMessages(),
- $event->getAccount()->getUserId(),
- Tag::LABEL_IMPORTANT
- );
- $messages = array_filter($messages, static function ($message) use ($doNotReclassify) {
- return ($message->getFlagImportant() === false || in_array($message->getMessageId(), $doNotReclassify, true));
- });
-
- try {
- $important = $this->tagMapper->getTagByImapLabel(Tag::LABEL_IMPORTANT, $event->getAccount()->getUserId());
- } catch (DoesNotExistException $e) {
- // just in case - if we get here, the tag is missing
- $this->logger->error('Could not find important tag for ' . $event->getAccount()->getUserId() . ' ' . $e->getMessage(), [
- 'exception' => $e,
- ]);
- return;
- }
-
- try {
- $predictions = $this->classifier->classifyImportance(
- $event->getAccount(),
- $messages
- );
-
- foreach ($event->getMessages() as $message) {
- if ($predictions[$message->getUid()] ?? false) {
- $this->mailManager->flagMessage($event->getAccount(), $event->getMailbox()->getName(), $message->getUid(), Tag::LABEL_IMPORTANT, true);
- $this->mailManager->tagMessage($event->getAccount(), $event->getMailbox()->getName(), $message, $important, true);
- }
- }
- } catch (ServiceException $e) {
- $this->logger->error('Could not classify incoming message importance: ' . $e->getMessage(), [
- 'exception' => $e,
- ]);
- }
- }
-}
diff --git a/lib/Migration/Version4100Date20241021091352.php b/lib/Migration/Version4100Date20241021091352.php
new file mode 100644
index 0000000000..ccdf8be264
--- /dev/null
+++ b/lib/Migration/Version4100Date20241021091352.php
@@ -0,0 +1,29 @@
+dropTable('mail_classifiers');
+ return $schema;
+ }
+}
diff --git a/lib/Model/Classifier.php b/lib/Model/Classifier.php
new file mode 100644
index 0000000000..df4d21eeb1
--- /dev/null
+++ b/lib/Model/Classifier.php
@@ -0,0 +1,134 @@
+accountId;
+ }
+
+ public function setAccountId(int $accountId): void {
+ $this->accountId = $accountId;
+ }
+
+ public function getType(): string {
+ return $this->type;
+ }
+
+ public function setType(string $type): void {
+ $this->type = $type;
+ }
+
+ public function getEstimator(): string {
+ return $this->estimator;
+ }
+
+ public function setEstimator(string $estimator): void {
+ $this->estimator = $estimator;
+ }
+
+ public function getPersistenceVersion(): int {
+ return $this->persistenceVersion;
+ }
+
+ public function setPersistenceVersion(int $persistenceVersion): void {
+ $this->persistenceVersion = $persistenceVersion;
+ }
+
+ public function getTrainingSetSize(): int {
+ return $this->trainingSetSize;
+ }
+
+ public function setTrainingSetSize(int $trainingSetSize): void {
+ $this->trainingSetSize = $trainingSetSize;
+ }
+
+ public function getValidationSetSize(): int {
+ return $this->validationSetSize;
+ }
+
+ public function setValidationSetSize(int $validationSetSize): void {
+ $this->validationSetSize = $validationSetSize;
+ }
+
+ public function getRecallImportant(): float {
+ return $this->recallImportant;
+ }
+
+ public function setRecallImportant(float $recallImportant): void {
+ $this->recallImportant = $recallImportant;
+ }
+
+ public function getPrecisionImportant(): float {
+ return $this->precisionImportant;
+ }
+
+ public function setPrecisionImportant(float $precisionImportant): void {
+ $this->precisionImportant = $precisionImportant;
+ }
+
+ public function getF1ScoreImportant(): float {
+ return $this->f1ScoreImportant;
+ }
+
+ public function setF1ScoreImportant(float $f1ScoreImportant): void {
+ $this->f1ScoreImportant = $f1ScoreImportant;
+ }
+
+ public function getDuration(): int {
+ return $this->duration;
+ }
+
+ public function setDuration(int $duration): void {
+ $this->duration = $duration;
+ }
+
+ public function getCreatedAt(): int {
+ return $this->createdAt;
+ }
+
+ public function setCreatedAt(int $createdAt): void {
+ $this->createdAt = $createdAt;
+ }
+
+ #[ReturnTypeWillChange]
+ public function jsonSerialize() {
+ return [
+ 'accountId' => $this->accountId,
+ 'type' => $this->type,
+ 'estimator' => $this->estimator,
+ 'persistenceVersion' => $this->persistenceVersion,
+ 'trainingSetSize' => $this->trainingSetSize,
+ 'validationSetSize' => $this->validationSetSize,
+ 'recallImportant' => $this->recallImportant,
+ 'precisionImportant' => $this->precisionImportant,
+ 'f1ScoreImportant' => $this->f1ScoreImportant,
+ 'duration' => $this->duration,
+ 'createdAt' => $this->createdAt,
+ ];
+ }
+}
diff --git a/lib/Model/ClassifierPipeline.php b/lib/Model/ClassifierPipeline.php
new file mode 100644
index 0000000000..f58d40ba89
--- /dev/null
+++ b/lib/Model/ClassifierPipeline.php
@@ -0,0 +1,29 @@
+estimator;
+ }
+
+ public function getExtractor(): IExtractor {
+ return $this->extractor;
+ }
+}
diff --git a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php
index eefadd92ba..54469bc220 100644
--- a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php
+++ b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php
@@ -3,7 +3,7 @@
declare(strict_types=1);
/**
- * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors
+ * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
@@ -17,18 +17,25 @@
* Combines a set of DI'ed extractors so they can be used as one class
*/
class CompositeExtractor implements IExtractor {
+ private readonly SubjectExtractor $subjectExtractor;
+
/** @var IExtractor[] */
- private $extractors;
+ private readonly array $extractors;
- public function __construct(ImportantMessagesExtractor $ex1,
+ public function __construct(
+ ImportantMessagesExtractor $ex1,
ReadMessagesExtractor $ex2,
RepliedMessagesExtractor $ex3,
- SentMessagesExtractor $ex4) {
+ SentMessagesExtractor $ex4,
+ SubjectExtractor $ex5,
+ ) {
+ $this->subjectExtractor = $ex5;
$this->extractors = [
$ex1,
$ex2,
$ex3,
$ex4,
+ $ex5,
];
}
@@ -46,4 +53,8 @@ public function extract(Message $message): array {
return $extractor->extract($message);
}, $this->extractors);
}
+
+ public function getSubjectExtractor(): SubjectExtractor {
+ return $this->subjectExtractor;
+ }
}
diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php
new file mode 100644
index 0000000000..9118c054b7
--- /dev/null
+++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php
@@ -0,0 +1,110 @@
+wordCountVectorizer = new WordCountVectorizer(self::MAX_VOCABULARY_SIZE);
+ $this->tfidf = new TfIdfTransformer();
+ }
+
+ public function getWordCountVectorizer(): WordCountVectorizer {
+ return $this->wordCountVectorizer;
+ }
+
+ public function setWordCountVectorizer(WordCountVectorizer $wordCountVectorizer): void {
+ $this->wordCountVectorizer = $wordCountVectorizer;
+ $this->limitFeatureSize();
+ }
+
+ public function getTfIdf(): TfIdfTransformer {
+ return $this->tfidf;
+ }
+
+ public function setTfidf(TfIdfTransformer $tfidf): void {
+ $this->tfidf = $tfidf;
+ }
+
+ public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void {
+ /** @var array> $data */
+ $data = array_map(static function (Message $message) {
+ return [
+ 'text' => $message->getSubject() ?? '',
+ 'label' => $message->getFlagImportant()
+ ? ImportanceClassifier::LABEL_IMPORTANT
+ : ImportanceClassifier::LABEL_NOT_IMPORTANT,
+ ];
+ }, $messages);
+
+ // Fit transformers
+ Labeled::build(array_column($data, 'text'), array_column($data, 'label'))
+ ->apply(new MultibyteTextNormalizer())
+ ->apply($this->wordCountVectorizer)
+ ->apply($this->tfidf);
+
+ $this->limitFeatureSize();
+ }
+
+ public function extract(Message $message): array {
+ $sender = $message->getFrom()->first();
+ if ($sender === null) {
+ throw new RuntimeException('This should not happen');
+ }
+
+ // Wrap this message's subject ('' when it is null) into a one-sample data set
+ $trainText = $message->getSubject() ?? '';
+
+ $trainDataSet = Unlabeled::build([[$trainText]])
+ ->apply(new MultibyteTextNormalizer())
+ ->apply($this->wordCountVectorizer)
+ ->apply($this->tfidf);
+
+ // Fall back to a zero vector of length $this->max when the transformed data set has no feature columns
+ if ($trainDataSet->numFeatures() === 0) {
+ $textFeatures = array_fill(0, $this->max, 0);
+ } else {
+ $textFeatures = $trainDataSet->sample(0);
+ }
+
+ return $textFeatures;
+ }
+
+ /**
+ * Limit feature vector length to actual size of vocabulary.
+ */
+ private function limitFeatureSize(): void {
+ $vocabularies = $this->wordCountVectorizer->vocabularies();
+ if (!isset($vocabularies[0])) {
+ // Should not happen but better safe than sorry
+ return;
+ }
+
+ $this->max = count($vocabularies[0]);
+ }
+}
diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php
index 434d0c5bcc..0368617308 100644
--- a/lib/Service/Classification/ImportanceClassifier.php
+++ b/lib/Service/Classification/ImportanceClassifier.php
@@ -3,30 +3,39 @@
declare(strict_types=1);
/**
- * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors
+ * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
namespace OCA\Mail\Service\Classification;
+use Closure;
use Horde_Imap_Client;
use OCA\Mail\Account;
-use OCA\Mail\Db\Classifier;
use OCA\Mail\Db\Mailbox;
use OCA\Mail\Db\MailboxMapper;
use OCA\Mail\Db\Message;
use OCA\Mail\Db\MessageMapper;
use OCA\Mail\Exception\ClassifierTrainingException;
use OCA\Mail\Exception\ServiceException;
+use OCA\Mail\Model\Classifier;
+use OCA\Mail\Model\ClassifierPipeline;
use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor;
+use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor;
use OCA\Mail\Support\PerformanceLogger;
+use OCA\Mail\Support\PerformanceLoggerTask;
use OCP\AppFramework\Db\DoesNotExistException;
+use Psr\Container\ContainerExceptionInterface;
+use Psr\Container\ContainerInterface;
use Psr\Log\LoggerInterface;
-use Rubix\ML\Classifiers\GaussianNB;
+use Rubix\ML\Classifiers\KNearestNeighbors;
use Rubix\ML\CrossValidation\Reports\MulticlassBreakdown;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Estimator;
+use Rubix\ML\Kernels\Distance\Manhattan;
+use Rubix\ML\Learner;
+use Rubix\ML\Persistable;
use RuntimeException;
use function array_column;
use function array_combine;
@@ -64,12 +73,12 @@ class ImportanceClassifier {
/**
* @var string label for data sets that are classified as important
*/
- private const LABEL_IMPORTANT = 'i';
+ public const LABEL_IMPORTANT = 'i';
/**
* @var string label for data sets that are classified as not important
*/
- private const LABEL_NOT_IMPORTANT = 'ni';
+ public const LABEL_NOT_IMPORTANT = 'ni';
/**
* The minimum number of important messages. Without those the unsupervised
@@ -81,7 +90,7 @@ class ImportanceClassifier {
/**
* The maximum number of data sets to train the classifier with
*/
- private const MAX_TRAINING_SET_SIZE = 1000;
+ private const MAX_TRAINING_SET_SIZE = 300;
/** @var MailboxMapper */
private $mailboxMapper;
@@ -89,9 +98,6 @@ class ImportanceClassifier {
/** @var MessageMapper */
private $messageMapper;
- /** @var CompositeExtractor */
- private $extractor;
-
/** @var PersistenceService */
private $persistenceService;
@@ -101,22 +107,40 @@ class ImportanceClassifier {
/** @var ImportanceRulesClassifier */
private $rulesClassifier;
- private LoggerInterface $logger;
+ private ContainerInterface $container;
public function __construct(MailboxMapper $mailboxMapper,
MessageMapper $messageMapper,
- CompositeExtractor $extractor,
PersistenceService $persistenceService,
PerformanceLogger $performanceLogger,
ImportanceRulesClassifier $rulesClassifier,
- LoggerInterface $logger) {
+ ContainerInterface $container) {
$this->mailboxMapper = $mailboxMapper;
$this->messageMapper = $messageMapper;
- $this->extractor = $extractor;
$this->persistenceService = $persistenceService;
$this->performanceLogger = $performanceLogger;
$this->rulesClassifier = $rulesClassifier;
- $this->logger = $logger;
+ $this->container = $container;
+ }
+
+ private static function createDefaultEstimator(): KNearestNeighbors {
+ // A meta estimator was trained on the same data multiple times to average out the
+ // variance of the trained model.
+ // Parameters were chosen from the best configuration across 100 runs.
+ // Both variance (spread) and f1 score were considered.
+ // Note: Lower k values yield slightly higher f1 scores but show higher variances.
+ return new KNearestNeighbors(15, true, new Manhattan());
+ }
+
+ /**
+ * @throws ServiceException If the extractor is not available
+ */
+ private function createExtractor(): CompositeExtractor {
+ try {
+ return $this->container->get(CompositeExtractor::class);
+ } catch (ContainerExceptionInterface $e) {
+ throw new ServiceException('Default extractor is not available', 0, $e);
+ }
}
private function filterMessageHasSenderEmail(Message $message): bool {
@@ -124,22 +148,24 @@ private function filterMessageHasSenderEmail(Message $message): bool {
}
/**
- * Train an account's classifier of important messages
- *
- * Train a classifier based on a user's existing messages to be able to derive
- * importance markers for new incoming messages.
- *
- * To factor in (server-side) filtering into multiple mailboxes, the algorithm
- * will not only look for messages in the inbox but also other non-special
- * mailboxes.
- *
- * To prevent memory exhaustion, the process will only load a fixed maximum
- * number of messages per account.
+ * Build a data set for training an importance classifier.
*
* @param Account $account
+ * @param IExtractor $extractor
+ * @param LoggerInterface $logger
+ * @param PerformanceLoggerTask|null $perf
+ * @param bool $shuffle
+ * @return array|null Returns null if there are not enough messages to train
*/
- public function train(Account $account, LoggerInterface $logger): void {
- $perf = $this->performanceLogger->start('importance classifier training');
+ public function buildDataSet(
+ Account $account,
+ IExtractor $extractor,
+ LoggerInterface $logger,
+ ?PerformanceLoggerTask $perf = null,
+ bool $shuffle = false,
+ ): ?array {
+ $perf ??= $this->performanceLogger->start('build data set for importance classifier training');
+
$incomingMailboxes = $this->getIncomingMailboxes($account);
$logger->debug('found ' . count($incomingMailboxes) . ' incoming mailbox(es)');
$perf->step('find incoming mailboxes');
@@ -160,30 +186,128 @@ public function train(Account $account, LoggerInterface $logger): void {
$logger->debug('found ' . count($messages) . ' messages of which ' . count($importantMessages) . ' are important');
if (count($importantMessages) < self::COLD_START_THRESHOLD) {
$logger->info('not enough messages to train a classifier');
- $perf->end();
- return;
+ return null;
}
$perf->step('find latest ' . self::MAX_TRAINING_SET_SIZE . ' messages');
- $dataSet = $this->getFeaturesAndImportance($account, $incomingMailboxes, $outgoingMailboxes, $messages);
- $perf->step('extract features from messages');
+ $dataSet = $this->getFeaturesAndImportance($account, $incomingMailboxes, $outgoingMailboxes, $messages, $extractor);
+ if ($shuffle) {
+ shuffle($dataSet);
+ }
+
+ return $dataSet;
+ }
+
+ /**
+ * Train an account's classifier of important messages
+ *
+ * Train a classifier based on a user's existing messages to be able to derive
+ * importance markers for new incoming messages.
+ *
+ * To factor in (server-side) filtering into multiple mailboxes, the algorithm
+ * will not only look for messages in the inbox but also other non-special
+ * mailboxes.
+ *
+ * To prevent memory exhaustion, the process will only load a fixed maximum
+ * number of messages per account.
+ *
+ * @param Account $account
+ * @param LoggerInterface $logger
+ * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used.
+ * @param bool $shuffleDataSet Shuffle the data set before training
+ * @param bool $persist Persist the trained classifier to use it for message classification
+ *
+ * @return ClassifierPipeline|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained
+ *
+ * @throws ServiceException
+ */
+ public function train(
+ Account $account,
+ LoggerInterface $logger,
+ ?Closure $estimator = null,
+ bool $shuffleDataSet = false,
+ bool $persist = true,
+ ): ?ClassifierPipeline {
+ $perf = $this->performanceLogger->start('importance classifier training');
+
+ $extractor = $this->createExtractor();
+ $dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet);
+ if ($dataSet === null) {
+ return null;
+ }
+
+ return $this->trainWithCustomDataSet(
+ $account,
+ $logger,
+ $dataSet,
+ $extractor,
+ $estimator,
+ $perf,
+ $persist,
+ );
+ }
+
+ /**
+ * Train a classifier using a custom data set.
+ *
+ * @param Account $account
+ * @param LoggerInterface $logger
+ * @param array $dataSet Training data set built by buildDataSet()
+ * @param CompositeExtractor $extractor Extractor used to extract the given data set
+ * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used.
+ * @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task
+ * @param bool $persist Persist the trained classifier to use it for message classification
+ *
+ * @return ClassifierPipeline|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained
+ *
+ * @throws ServiceException
+ */
+ private function trainWithCustomDataSet(
+ Account $account,
+ LoggerInterface $logger,
+ array $dataSet,
+ CompositeExtractor $extractor,
+ ?Closure $estimator,
+ ?PerformanceLoggerTask $perf = null,
+ bool $persist = true,
+ ): ?ClassifierPipeline {
+ $perf ??= $this->performanceLogger->start('importance classifier training');
+ $estimator ??= self::createDefaultEstimator(...);
/**
* How many of the most recent messages are excluded from training?
*/
$validationThreshold = max(
5,
- (int)(count($dataSet) * 0.1)
+ (int)(count($dataSet) * 0.2)
);
$validationSet = array_slice($dataSet, 0, $validationThreshold);
$trainingSet = array_slice($dataSet, $validationThreshold);
- $logger->debug('data set split into ' . count($trainingSet) . ' training and ' . count($validationSet) . ' validation sets with ' . count($trainingSet[0]['features'] ?? []) . ' dimensions');
+
+ $validationSetImportantCount = 0;
+ $trainingSetImportantCount = 0;
+ foreach ($validationSet as $data) {
+ if ($data['label'] === self::LABEL_IMPORTANT) {
+ $validationSetImportantCount++;
+ }
+ }
+ foreach ($trainingSet as $data) {
+ if ($data['label'] === self::LABEL_IMPORTANT) {
+ $trainingSetImportantCount++;
+ }
+ }
+
+ $logger->debug('data set split into ' . count($trainingSet) . ' (' . self::LABEL_IMPORTANT . ': ' . $trainingSetImportantCount . ') training and ' . count($validationSet) . ' (' . self::LABEL_IMPORTANT . ': ' . $validationSetImportantCount . ') validation sets with ' . count($trainingSet[0]['features'] ?? []) . ' dimensions');
+
if ($validationSet === [] || $trainingSet === []) {
$logger->info('not enough messages to train a classifier');
$perf->end();
- return;
+ return null;
}
- $validationEstimator = $this->trainClassifier($trainingSet);
+
+ /** @var Learner&Estimator&Persistable $validationEstimator */
+ $validationEstimator = $estimator();
+ $this->trainClassifier($validationEstimator, $trainingSet); // fit on the training split only; the validation split must stay unseen until validation
try {
$classifier = $this->validateClassifier(
$validationEstimator,
@@ -196,19 +320,31 @@ public function train(Account $account, LoggerInterface $logger): void {
'exception' => $e,
]);
$perf->end();
- return;
+ return null;
}
$perf->step('train and validate classifier with training and validation sets');
- $estimator = $this->trainClassifier($dataSet);
- $perf->step('train classifier with full data set');
+ if (!$persist) {
+ return new ClassifierPipeline($validationEstimator, $extractor);
+ }
- $classifier->setAccountId($account->getId());
+ /** @var Learner&Estimator&Persistable $persistedEstimator */
+ $persistedEstimator = $estimator();
+ $this->trainClassifier($persistedEstimator, $dataSet);
+ $perf->step('train classifier with full data set');
$classifier->setDuration($perf->end());
- $this->persistenceService->persist($classifier, $estimator);
- $logger->debug("classifier {$classifier->getId()} persisted");
+ $classifier->setAccountId($account->getId());
+ $classifier->setEstimator(get_class($persistedEstimator));
+ $classifier->setPersistenceVersion(PersistenceService::VERSION);
+
+ $this->persistenceService->persist($account, $persistedEstimator, $extractor);
+ $logger->debug("Classifier for account {$account->getId()} persisted", [
+ 'classifier' => $classifier,
+ ]);
+ return new ClassifierPipeline($persistedEstimator, $extractor);
}
+
/**
* @param Account $account
*
@@ -259,17 +395,20 @@ private function getOutgoingMailboxes(Account $account): array {
private function getFeaturesAndImportance(Account $account,
array $incomingMailboxes,
array $outgoingMailboxes,
- array $messages): array {
- $this->extractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages);
+ array $messages,
+ IExtractor $extractor): array {
+ $extractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages);
- return array_map(function (Message $message) {
+ return array_map(static function (Message $message) use ($extractor) {
$sender = $message->getFrom()->first();
if ($sender === null) {
throw new RuntimeException('This should not happen');
}
+ $features = $extractor->extract($message);
+
return [
- 'features' => $this->extractor->extract($message),
+ 'features' => $features,
'label' => $message->getFlagImportant() ? self::LABEL_IMPORTANT : self::LABEL_NOT_IMPORTANT,
'sender' => $sender->getEmail(),
];
@@ -279,21 +418,38 @@ private function getFeaturesAndImportance(Account $account,
/**
* @param Account $account
* @param Message[] $messages
+ * @param LoggerInterface $logger
*
* @return bool[]
+ *
* @throws ServiceException
*/
- public function classifyImportance(Account $account, array $messages): array {
- $estimator = null;
+ public function classifyImportance(Account $account,
+ array $messages,
+ LoggerInterface $logger): array {
+ $pipeline = null;
try {
- $estimator = $this->persistenceService->loadLatest($account);
+ $pipeline = $this->persistenceService->loadLatest($account);
} catch (ServiceException $e) {
- $this->logger->warning('Failed to load importance classifier: ' . $e->getMessage(), [
+ $logger->warning('Failed to load persisted estimator and extractor: ' . $e->getMessage(), [
'exception' => $e,
]);
}
- if ($estimator === null) {
+ // Persistence is disabled on some instances (due to no memory cache being available).
+ // Try to train a classifier on-the-fly on those instances.
+ if ($pipeline === null) {
+ $pipeline = $this->train($account, $logger);
+ }
+
+ // Can't train pipeline and no persistence available? -> Skip rule based classifier ...
+ // It won't yield good results. Instead, we have to wait for the user to accumulate more
+ // emails so that training a classifier succeeds.
+ if ($pipeline === null && !$this->persistenceService->isAvailable()) {
+ return [];
+ }
+
+ if ($pipeline === null) {
$predictions = $this->rulesClassifier->classifyImportance(
$account,
$this->getIncomingMailboxes($account),
@@ -309,15 +465,16 @@ public function classifyImportance(Account $account, array $messages): array {
}, $messages)
);
}
- $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']);
+ $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']);
$features = $this->getFeaturesAndImportance(
$account,
$this->getIncomingMailboxes($account),
$this->getOutgoingMailboxes($account),
- $messagesWithSender
+ $messagesWithSender,
+ $pipeline->getExtractor(),
);
- $predictions = $estimator->predict(
+ $predictions = $pipeline->getEstimator()->predict(
Unlabeled::build(array_column($features, 'features'))
);
return array_combine(
@@ -330,22 +487,20 @@ public function classifyImportance(Account $account, array $messages): array {
);
}
- private function trainClassifier(array $trainingSet): GaussianNB {
- $classifier = new GaussianNB();
+ private function trainClassifier(Learner $classifier, array $trainingSet): void {
$classifier->train(Labeled::build(
array_column($trainingSet, 'features'),
array_column($trainingSet, 'label')
));
- return $classifier;
}
/**
* @param Estimator $estimator
* @param array $trainingSet
* @param array $validationSet
+ * @param LoggerInterface $logger
*
* @return Classifier
- * @throws ClassifierTrainingException
*/
private function validateClassifier(Estimator $estimator,
array $trainingSet,
diff --git a/lib/Service/Classification/NewMessagesClassifier.php b/lib/Service/Classification/NewMessagesClassifier.php
new file mode 100644
index 0000000000..64602f55b4
--- /dev/null
+++ b/lib/Service/Classification/NewMessagesClassifier.php
@@ -0,0 +1,109 @@
+insertBulk()).
+ *
+ * @param Message[] $messages
+ * @param Mailbox $mailbox
+ * @param Account $account
+ * @param Tag $importantTag
+ * @return void
+ */
+ public function classifyNewMessages(
+ array $messages,
+ Mailbox $mailbox,
+ Account $account,
+ Tag $importantTag,
+ ): void {
+ // Users can opt out of automatic tagging via the 'tag-classified-messages' preference
+ $allowTagging = $this->preferences->getPreference($account->getUserId(), 'tag-classified-messages');
+ if ($allowTagging === 'false') {
+ return;
+ }
+
+ foreach (self::EXEMPT_FROM_CLASSIFICATION as $specialUse) {
+ if ($mailbox->isSpecialUse($specialUse)) {
+ // Nothing to do then
+ return;
+ }
+ }
+
+ // if this is a message that's been flagged / tagged as important before, we don't want to reclassify it again.
+ // NOTE(review): the filter below KEEPS messages whose id is in $doNotReclassify, which looks
+ // inverted relative to the sentence above - confirm the intended condition before changing it.
+ $doNotReclassify = $this->tagMapper->getTaggedMessageIdsForMessages(
+ $messages,
+ $account->getUserId(),
+ Tag::LABEL_IMPORTANT
+ );
+ $messages = array_filter($messages, static function ($message) use ($doNotReclassify) {
+ return ($message->getFlagImportant() === false || in_array($message->getMessageId(), $doNotReclassify, true));
+ });
+ if ($messages === []) {
+ // Nothing left to classify
+ return;
+ }
+
+ try {
+ $predictions = $this->classifier->classifyImportance(
+ $account,
+ $messages,
+ $this->logger
+ );
+
+ foreach ($messages as $message) {
+ // The classifier may return no prediction for a message (e.g. an empty result set),
+ // so resolve it once with a "not important" default for both the log line and the flagging.
+ $isImportant = $predictions[$message->getUid()] ?? false;
+ $this->logger->info("Message {$message->getUid()} ({$message->getPreviewText()}) is " . ($isImportant ? 'important' : 'not important'));
+ if ($isImportant) {
+ $message->setFlagImportant(true);
+ $this->mailManager->flagMessage($account, $mailbox->getName(), $message->getUid(), Tag::LABEL_IMPORTANT, true);
+ $this->mailManager->tagMessage($account, $mailbox->getName(), $message, $importantTag, true);
+ }
+ }
+ } catch (ServiceException $e) {
+ $this->logger->error('Could not classify incoming message importance: ' . $e->getMessage(), [
+ 'exception' => $e,
+ ]);
+ } catch (ClientException $e) {
+ $this->logger->error('Could not persist incoming message importance to IMAP: ' . $e->getMessage(), [
+ 'exception' => $e,
+ ]);
+ }
+ }
+}
diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php
index 5cd73c09c6..ccfb40b0ff 100644
--- a/lib/Service/Classification/PersistenceService.php
+++ b/lib/Service/Classification/PersistenceService.php
@@ -3,258 +3,232 @@
declare(strict_types=1);
/**
- * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors
+ * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
namespace OCA\Mail\Service\Classification;
use OCA\Mail\Account;
-use OCA\Mail\AppInfo\Application;
-use OCA\Mail\Db\Classifier;
-use OCA\Mail\Db\ClassifierMapper;
-use OCA\Mail\Db\MailAccountMapper;
use OCA\Mail\Exception\ServiceException;
-use OCP\App\IAppManager;
-use OCP\AppFramework\Db\DoesNotExistException;
-use OCP\AppFramework\Utility\ITimeFactory;
-use OCP\Files\IAppData;
-use OCP\Files\NotFoundException;
-use OCP\Files\NotPermittedException;
+use OCA\Mail\Model\ClassifierPipeline;
+use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor;
+use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor;
+use OCP\ICache;
use OCP\ICacheFactory;
-use OCP\ITempManager;
-use Psr\Log\LoggerInterface;
-use Rubix\ML\Estimator;
+use Psr\Container\ContainerExceptionInterface;
+use Psr\Container\ContainerInterface;
use Rubix\ML\Learner;
use Rubix\ML\Persistable;
use Rubix\ML\PersistentModel;
-use Rubix\ML\Persisters\Filesystem;
+use Rubix\ML\Serializers\RBX;
+use Rubix\ML\Transformers\TfIdfTransformer;
+use Rubix\ML\Transformers\Transformer;
+use Rubix\ML\Transformers\WordCountVectorizer;
use RuntimeException;
-use function file_get_contents;
-use function file_put_contents;
use function get_class;
-use function strlen;
class PersistenceService {
- private const ADD_DATA_FOLDER = 'classifiers';
+ // Increment the version when changing the classifier or transformer pipeline
+ public const VERSION = 1;
- /** @var ClassifierMapper */
- private $mapper;
-
- /** @var IAppData */
- private $appData;
-
- /** @var ITempManager */
- private $tempManager;
-
- /** @var ITimeFactory */
- private $timeFactory;
-
- /** @var IAppManager */
- private $appManager;
-
- /** @var ICacheFactory */
- private $cacheFactory;
-
- /** @var LoggerInterface */
- private $logger;
-
- /** @var MailAccountMapper */
- private $accountMapper;
-
- public function __construct(ClassifierMapper $mapper,
- IAppData $appData,
- ITempManager $tempManager,
- ITimeFactory $timeFactory,
- IAppManager $appManager,
- ICacheFactory $cacheFactory,
- LoggerInterface $logger,
- MailAccountMapper $accountMapper) {
- $this->mapper = $mapper;
- $this->appData = $appData;
- $this->tempManager = $tempManager;
- $this->timeFactory = $timeFactory;
- $this->appManager = $appManager;
- $this->cacheFactory = $cacheFactory;
- $this->logger = $logger;
- $this->accountMapper = $accountMapper;
+ public function __construct(
+ private readonly ICacheFactory $cacheFactory,
+ private readonly ContainerInterface $container,
+ ) {
}
/**
- * Persist the classifier data to the database and the estimator to storage
+ * Persist classifier, estimator and its transformers to the memory cache.
*
- * @param Classifier $classifier
* @param Learner&Persistable $estimator
*
- * @throws ServiceException
+ * @throws ServiceException If any serialization fails
*/
- public function persist(Classifier $classifier, Learner $estimator): void {
- /*
- * First we have to insert the row to get the unique ID, but disable
- * it until the model is persisted as well. Otherwise another process
- * might try to load the model in the meantime and run into an error
- * due to the missing data in app data.
- */
- $classifier->setAppVersion($this->appManager->getAppVersion(Application::APP_ID));
- $classifier->setEstimator(get_class($estimator));
- $classifier->setActive(false);
- $classifier->setCreatedAt($this->timeFactory->getTime());
- $this->mapper->insert($classifier);
+ public function persist(
+ Account $account,
+ Learner $estimator,
+ CompositeExtractor $extractor,
+ ): void {
+ $serializedData = [];
/*
- * Then we serialize the estimator into a temporary file
+ * First we serialize the estimator
*/
- $tmpPath = $this->tempManager->getTemporaryFile();
try {
- $model = new PersistentModel($estimator, new Filesystem($tmpPath));
+ $persister = new RubixMemoryPersister();
+ $model = new PersistentModel($estimator, $persister);
$model->save();
- $serializedClassifier = file_get_contents($tmpPath);
- $this->logger->debug('Serialized classifier written to tmp file (' . strlen($serializedClassifier) . 'B');
+ $serializedData[] = $persister->getData();
} catch (RuntimeException $e) {
throw new ServiceException('Could not serialize classifier: ' . $e->getMessage(), 0, $e);
}
/*
- * Then we store the serialized model to app data
+ * Then we serialize the transformer pipeline
*/
- try {
+ $transformers = [
+ $extractor->getSubjectExtractor()->getWordCountVectorizer(),
+ $extractor->getSubjectExtractor()->getTfIdf(),
+ ];
+ $serializer = new RBX();
+ foreach ($transformers as $transformer) {
try {
- $folder = $this->appData->getFolder(self::ADD_DATA_FOLDER);
- $this->logger->debug('Using existing folder for the serialized classifier');
- } catch (NotFoundException $e) {
- $folder = $this->appData->newFolder(self::ADD_DATA_FOLDER);
- $this->logger->debug('New folder created for serialized classifiers');
+ $persister = new RubixMemoryPersister();
+ /**
+ * This is how to serialize a transformer according to the official docs.
+ * PersistentModel can only be used on Learners which transformers don't implement.
+ *
+ * Ref https://docs.rubixml.com/2.0/model-persistence.html#persisting-transformers
+ *
+ * @psalm-suppress InternalMethod
+ */
+ $serializer->serialize($transformer)->saveTo($persister);
+ $serializedData[] = $persister->getData();
+ } catch (RuntimeException $e) {
+ throw new ServiceException('Could not serialize transformer: ' . $e->getMessage(), 0, $e);
}
- $file = $folder->newFile((string)$classifier->getId());
- $file->putContent($serializedClassifier);
- $this->logger->debug('Serialized classifier written to app data');
- } catch (NotPermittedException $e) {
- throw new ServiceException('Could not create classifiers directory: ' . $e->getMessage(), 0, $e);
}
- /*
- * Now we set the model active so it can be used by the next request
- */
- $classifier->setActive(true);
- $this->mapper->update($classifier);
+ $this->setCached((string)$account->getId(), $serializedData);
}
/**
- * @param Account $account
+ * Load the latest estimator and its transformers.
*
- * @return Estimator|null
- * @throws ServiceException
+ * @throws ServiceException If any deserialization fails
*/
- public function loadLatest(Account $account): ?Estimator {
- try {
- $latestModel = $this->mapper->findLatest($account->getId());
- } catch (DoesNotExistException $e) {
+ public function loadLatest(Account $account): ?ClassifierPipeline {
+ $cached = $this->getCached((string)$account->getId());
+ if ($cached === null) {
return null;
}
- return $this->load($latestModel->getId());
- }
- /**
- * @param int $id
- *
- * @return Estimator
- * @throws ServiceException
- */
- public function load(int $id): Estimator {
- $cached = $this->getCached($id);
- if ($cached !== null) {
- $this->logger->debug("Using cached serialized classifier $id");
- $serialized = $cached;
- } else {
- $this->logger->debug('Loading serialized classifier from app data');
+ $serializedModel = $cached[0];
+ $serializedTransformers = array_slice($cached, 1);
+ try {
+ $estimator = PersistentModel::load(new RubixMemoryPersister($serializedModel));
+ } catch (RuntimeException $e) {
+ throw new ServiceException(
+ 'Could not deserialize persisted classifier: ' . $e->getMessage(),
+ 0,
+ $e,
+ );
+ }
+
+ $serializer = new RBX();
+ $transformers = array_map(function (string $serializedTransformer) use ($serializer) {
try {
- $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER);
- $modelFile = $modelsFolder->getFile((string)$id);
- } catch (NotFoundException $e) {
- $this->logger->debug("Could not load classifier $id: " . $e->getMessage());
- throw new ServiceException("Could not load classifier $id: " . $e->getMessage(), 0, $e);
+ $persister = new RubixMemoryPersister($serializedTransformer);
+ $transformer = $persister->load()->deserializeWith($serializer);
+ } catch (RuntimeException $e) {
+ throw new ServiceException(
+ 'Could not deserialize persisted transformer of classifier: ' . $e->getMessage(),
+ 0,
+ $e,
+ );
}
- try {
- $serialized = $modelFile->getContent();
- } catch (NotFoundException|NotPermittedException $e) {
- $this->logger->debug("Could not load content for model file with classifier id $id: " . $e->getMessage());
- throw new ServiceException("Could not load content for model file with classifier id $id: " . $e->getMessage(), 0, $e);
+ if (!($transformer instanceof Transformer)) {
+ throw new ServiceException(sprintf(
+ 'Transformer is not an instance of %s: Got %s',
+ Transformer::class,
+ get_class($transformer),
+ ));
}
- $size = strlen($serialized);
- $this->logger->debug("Serialized classifier loaded (size=$size)");
- $this->cache($id, $serialized);
+ return $transformer;
+ }, $serializedTransformers);
+
+ $extractor = $this->loadExtractor($transformers);
+
+ return new ClassifierPipeline($estimator, $extractor);
+ }
+
+ /**
+ * Load and instantiate extractor based on the given transformers.
+ *
+ * @throws ServiceException If the transformers array contains unexpected instances or the composite extractor can't be instantiated
+ */
+ private function loadExtractor(array $transformers): IExtractor {
+ // A truncated or malformed cache entry may hold fewer transformers than expected.
+ // Fall back to null and report the actual type via get_debug_type(), which also accepts
+ // non-objects, so the failure surfaces as the documented ServiceException.
+ $wordCountVectorizer = $transformers[0] ?? null;
+ if (!($wordCountVectorizer instanceof WordCountVectorizer)) {
+ throw new ServiceException(sprintf(
+ 'Failed to load persisted transformer: Expected %s, got %s',
+ WordCountVectorizer::class,
+ get_debug_type($wordCountVectorizer),
+ ));
+ }
- $tmpPath = $this->tempManager->getTemporaryFile();
- file_put_contents($tmpPath, $serialized);
+ $tfidfTransformer = $transformers[1] ?? null;
+ if (!($tfidfTransformer instanceof TfIdfTransformer)) {
+ throw new ServiceException(sprintf(
+ 'Failed to load persisted transformer: Expected %s, got %s',
+ TfIdfTransformer::class,
+ get_debug_type($tfidfTransformer),
+ ));
+ }
try {
- $estimator = PersistentModel::load(new Filesystem($tmpPath));
- } catch (RuntimeException $e) {
- throw new ServiceException("Could not deserialize persisted classifier $id: " . $e->getMessage(), 0, $e);
+ /** @var CompositeExtractor $extractor */
+ $extractor = $this->container->get(CompositeExtractor::class);
+ } catch (ContainerExceptionInterface $e) {
+ throw new ServiceException('Failed to instantiate the composite extractor', 0, $e);
}
- return $estimator;
+ $extractor->getSubjectExtractor()->setWordCountVectorizer($wordCountVectorizer);
+ $extractor->getSubjectExtractor()->setTfidf($tfidfTransformer);
+ return $extractor;
}
- public function cleanUp(): void {
- $threshold = $this->timeFactory->getTime() - 30 * 24 * 60 * 60;
- $totalAccounts = $this->accountMapper->getTotal();
- $classifiers = $this->mapper->findHistoric($threshold, $totalAccounts * 10);
- foreach ($classifiers as $classifier) {
- try {
- $this->deleteModel($classifier->getId());
- $this->mapper->delete($classifier);
- } catch (NotPermittedException $e) {
- // Log and continue. This is not critical
- $this->logger->warning('Could not clean-up old classifier', [
- 'id' => $classifier->getId(),
- 'exception' => $e,
- ]);
- }
+ private function getCacheInstance(): ?ICache {
+ if (!$this->isAvailable()) {
+ return null;
}
+
+ $version = self::VERSION;
+ return $this->cacheFactory->createDistributed("mail-classifier/v$version/");
}
/**
- * @throws NotPermittedException
+ * @return string[]|null Array of serialized classifier and transformers
*/
- private function deleteModel(int $id): void {
- $this->logger->debug('Deleting serialized classifier from app data', [
- 'id' => $id,
- ]);
- try {
- $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER);
- $modelFile = $modelsFolder->getFile((string)$id);
- $modelFile->delete();
- } catch (NotFoundException $e) {
- $this->logger->debug("Classifier model $id does not exist", [
- 'exception' => $e,
- ]);
+ private function getCached(string $id): ?array {
+ $cache = $this->getCacheInstance();
+ if ($cache === null) {
+ return null;
}
- }
-
- private function getCacheKey(int $id): string {
- return "mail_classifier_$id";
- }
- private function getCached(int $id): ?string {
- if (!$this->cacheFactory->isLocalCacheAvailable()) {
+ $json = $cache->get($id);
+ if (!is_string($json)) {
return null;
}
- $cache = $this->cacheFactory->createLocal();
- return $cache->get(
- $this->getCacheKey($id)
- );
+ // Treat a corrupt or unexpected cache entry as a cache miss instead of failing with a TypeError
+ $encodedData = json_decode($json, true);
+ if (!is_array($encodedData) || $encodedData === []) {
+ return null;
+ }
+
+ $serializedData = [];
+ foreach ($encodedData as $encoded) {
+ // Strict mode makes base64_decode() return false on characters outside the base64 alphabet
+ $decoded = is_string($encoded) ? base64_decode($encoded, true) : false;
+ if ($decoded === false) {
+ return null;
+ }
+ $serializedData[] = $decoded;
+ }
+ return $serializedData;
}
- private function cache(int $id, string $serialized): void {
- if (!$this->cacheFactory->isLocalCacheAvailable()) {
+ /**
+ * @param string[] $serializedData Array of serialized classifier and transformers
+ */
+ private function setCached(string $id, array $serializedData): void {
+ $cache = $this->getCacheInstance();
+ if ($cache === null) {
return;
}
- $cache = $this->cacheFactory->createLocal();
- $cache->set($this->getCacheKey($id), $serialized);
+
+ // Serialized data contains binary, non-utf8 data so we encode it as base64 first
+ $encodedData = array_map(base64_encode(...), $serializedData);
+ $json = json_encode($encodedData, JSON_THROW_ON_ERROR);
+
+ // Set a ttl of a week because a new model will be generated daily
+ $cache->set($id, $json, 3600 * 24 * 7);
+ }
+
+ /**
+ * Returns true if the persistence layer is available on this Nextcloud server.
+ */
+ public function isAvailable(): bool {
+ return $this->cacheFactory->isAvailable();
}
}
diff --git a/lib/Service/Classification/RubixMemoryPersister.php b/lib/Service/Classification/RubixMemoryPersister.php
new file mode 100644
index 0000000000..c3abff2463
--- /dev/null
+++ b/lib/Service/Classification/RubixMemoryPersister.php
@@ -0,0 +1,36 @@
+data;
+ }
+
+ public function save(Encoding $encoding): void {
+ $this->data = $encoding->data();
+ }
+
+ public function load(): Encoding {
+ return new Encoding($this->data);
+ }
+
+ public function __toString() {
+ return self::class;
+ }
+}
diff --git a/lib/Service/CleanupService.php b/lib/Service/CleanupService.php
index 65fa421200..74f5c0070a 100644
--- a/lib/Service/CleanupService.php
+++ b/lib/Service/CleanupService.php
@@ -17,7 +17,6 @@
use OCA\Mail\Db\MessageRetentionMapper;
use OCA\Mail\Db\MessageSnoozeMapper;
use OCA\Mail\Db\TagMapper;
-use OCA\Mail\Service\Classification\PersistenceService;
use OCA\Mail\Support\PerformanceLogger;
use OCP\AppFramework\Utility\ITimeFactory;
use Psr\Log\LoggerInterface;
@@ -44,7 +43,6 @@ class CleanupService {
private MessageSnoozeMapper $messageSnoozeMapper;
- private PersistenceService $classifierPersistenceService;
private ITimeFactory $timeFactory;
public function __construct(MailAccountMapper $mailAccountMapper,
@@ -55,7 +53,6 @@ public function __construct(MailAccountMapper $mailAccountMapper,
TagMapper $tagMapper,
MessageRetentionMapper $messageRetentionMapper,
MessageSnoozeMapper $messageSnoozeMapper,
- PersistenceService $classifierPersistenceService,
ITimeFactory $timeFactory) {
$this->aliasMapper = $aliasMapper;
$this->mailboxMapper = $mailboxMapper;
@@ -64,7 +61,6 @@ public function __construct(MailAccountMapper $mailAccountMapper,
$this->tagMapper = $tagMapper;
$this->messageRetentionMapper = $messageRetentionMapper;
$this->messageSnoozeMapper = $messageSnoozeMapper;
- $this->classifierPersistenceService = $classifierPersistenceService;
$this->mailAccountMapper = $mailAccountMapper;
$this->timeFactory = $timeFactory;
}
@@ -92,8 +88,6 @@ public function cleanUp(LoggerInterface $logger): void {
$task->step('delete expired messages');
$this->messageSnoozeMapper->deleteOrphans();
$task->step('delete orphan snoozes');
- $this->classifierPersistenceService->cleanUp();
- $task->step('delete orphan classifiers');
$task->end();
}
}
diff --git a/lib/Service/Sync/ImapToDbSynchronizer.php b/lib/Service/Sync/ImapToDbSynchronizer.php
index 9c52aad632..dd4af7add2 100644
--- a/lib/Service/Sync/ImapToDbSynchronizer.php
+++ b/lib/Service/Sync/ImapToDbSynchronizer.php
@@ -18,6 +18,8 @@
use OCA\Mail\Db\Mailbox;
use OCA\Mail\Db\MailboxMapper;
use OCA\Mail\Db\MessageMapper as DatabaseMessageMapper;
+use OCA\Mail\Db\Tag;
+use OCA\Mail\Db\TagMapper;
use OCA\Mail\Events\NewMessagesSynchronized;
use OCA\Mail\Events\SynchronizationEvent;
use OCA\Mail\Exception\ClientException;
@@ -32,7 +34,9 @@
use OCA\Mail\IMAP\Sync\Request;
use OCA\Mail\IMAP\Sync\Synchronizer;
use OCA\Mail\Model\IMAPMessage;
+use OCA\Mail\Service\Classification\NewMessagesClassifier;
use OCA\Mail\Support\PerformanceLogger;
+use OCP\AppFramework\Db\DoesNotExistException;
use OCP\EventDispatcher\IEventDispatcher;
use Psr\Log\LoggerInterface;
use Throwable;
@@ -72,15 +76,21 @@ class ImapToDbSynchronizer {
/** @var IMailManager */
private $mailManager;
+ private TagMapper $tagMapper;
+ private NewMessagesClassifier $newMessagesClassifier;
+
public function __construct(DatabaseMessageMapper $dbMapper,
IMAPClientFactory $clientFactory,
ImapMessageMapper $imapMapper,
MailboxMapper $mailboxMapper,
+ DatabaseMessageMapper $messageMapper,
Synchronizer $synchronizer,
IEventDispatcher $dispatcher,
PerformanceLogger $performanceLogger,
LoggerInterface $logger,
- IMailManager $mailManager) {
+ IMailManager $mailManager,
+ TagMapper $tagMapper,
+ NewMessagesClassifier $newMessagesClassifier) {
$this->dbMapper = $dbMapper;
$this->clientFactory = $clientFactory;
$this->imapMapper = $imapMapper;
@@ -90,6 +100,8 @@ public function __construct(DatabaseMessageMapper $dbMapper,
$this->performanceLogger = $performanceLogger;
$this->logger = $logger;
$this->mailManager = $mailManager;
+ $this->tagMapper = $tagMapper;
+ $this->newMessagesClassifier = $newMessagesClassifier;
}
/**
@@ -105,9 +117,9 @@ public function syncAccount(Account $account,
$snoozeMailboxId = $account->getMailAccount()->getSnoozeMailboxId();
$sentMailboxId = $account->getMailAccount()->getSentMailboxId();
$trashRetentionDays = $account->getMailAccount()->getTrashRetentionDays();
-
+
$client = $this->clientFactory->getClient($account);
-
+
foreach ($this->mailboxMapper->findAll($account) as $mailbox) {
$syncTrash = $trashMailboxId === $mailbox->getId() && $trashRetentionDays !== null;
$syncSnooze = $snoozeMailboxId === $mailbox->getId();
@@ -131,7 +143,7 @@ public function syncAccount(Account $account,
$rebuildThreads = true;
}
}
-
+
$client->logout();
$this->dispatcher->dispatchTyped(
@@ -413,6 +425,15 @@ private function runPartialSync(
});
}
+ $importantTag = null;
+ try {
+ $importantTag = $this->tagMapper->getTagByImapLabel(Tag::LABEL_IMPORTANT, $account->getUserId());
+ } catch (DoesNotExistException $e) {
+ $this->logger->error('Could not find important tag for ' . $account->getUserId() . ' ' . $e->getMessage(), [
+ 'exception' => $e,
+ ]);
+ }
+
foreach (array_chunk($newMessages, 500) as $chunk) {
$dbMessages = array_map(static function (IMAPMessage $imapMessage) use ($mailbox, $account) {
return $imapMessage->toDbMessage($mailbox->getId(), $account->getMailAccount());
@@ -420,6 +441,15 @@ private function runPartialSync(
$this->dbMapper->insertBulk($account, ...$dbMessages);
+ if ($importantTag) {
+ $this->newMessagesClassifier->classifyNewMessages(
+ $dbMessages,
+ $mailbox,
+ $account,
+ $importantTag,
+ );
+ }
+
$this->dispatcher->dispatch(
NewMessagesSynchronized::class,
new NewMessagesSynchronized($account, $mailbox, $dbMessages)