diff --git a/lib/Db/Classifier.php b/lib/Db/Classifier.php index 46f064cff1..70f945a5e8 100644 --- a/lib/Db/Classifier.php +++ b/lib/Db/Classifier.php @@ -64,9 +64,6 @@ * @method int getCreatedAt() * @method void setCreatedAt(int $createdAt) * - * @method int getTransformerCount() - * @method void setTransformerCount(int $transformerCount) - * * @method string|null getTransformers() * @method void setTransformers(string|null $transformers) */ @@ -109,12 +106,6 @@ class Classifier extends Entity { /** @var int */ protected $createdAt; - /** @var int */ - protected $transformerCount; - - /** @var string */ - protected $transformers; - public function __construct() { $this->addType('accountId', 'int'); $this->addType('type', 'string'); @@ -127,7 +118,5 @@ public function __construct() { $this->addType('duration', 'int'); $this->addType('active', 'boolean'); $this->addType('createdAt', 'int'); - $this->addType('transformerCount', 'int'); - $this->addType('transformers', 'string'); } } diff --git a/lib/Migration/Version3100Date20230324113141.php b/lib/Migration/Version3100Date20230324113141.php deleted file mode 100644 index 2181e24a73..0000000000 --- a/lib/Migration/Version3100Date20230324113141.php +++ /dev/null @@ -1,56 +0,0 @@ - - * - * @author Richard Steinmetz - * - * @license GNU AGPL version 3 or any later version - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of the - * License, or (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - * - */ - -namespace OCA\Mail\Migration; - -use Closure; -use OCP\DB\ISchemaWrapper; -use OCP\DB\Types; -use OCP\Migration\IOutput; -use OCP\Migration\SimpleMigrationStep; - -class Version3100Date20230324113141 extends SimpleMigrationStep { - /** - * @param IOutput $output - * @param Closure(): ISchemaWrapper $schemaClosure - * @param array $options - * @return null|ISchemaWrapper - */ - public function changeSchema(IOutput $output, Closure $schemaClosure, array $options): ?ISchemaWrapper { - /** @var ISchemaWrapper $schema */ - $schema = $schemaClosure(); - - $classifierTable = $schema->getTable('mail_classifiers'); - if (!$classifierTable->hasColumn('transformer_count')) { - $classifierTable->addColumn('transformer_count', Types::INTEGER, [ - 'notnull' => true, - 'default' => 0, - ]); - } - - return $schema; - } -} diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index e6262c62fa..bf4aa6f097 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -461,7 +461,7 @@ public function classifyImportance(Account $account, try { $pipeline = $this->persistenceService->loadLatest($account); } catch (ServiceException $e) { - $logger->warning('Failed to load importance classifier: ' . $e->getMessage(), [ + $logger->warning('Failed to load persisted estimator and extractor: ' . $e->getMessage(), [ 'exception' => $e, ]); } @@ -482,41 +482,17 @@ public function classifyImportance(Account $account, }, $messages) ); } - $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']); - - // Load persisted transformers of the subject extractor. - // Is a bit hacky but a full abstraction would be overkill. - $transformers = $pipeline->getTransformers(); - if (count($transformers) === 2) { - $wordCountVectorizer = $transformers[0]; - if (!($wordCountVectorizer instanceof WordCountVectorizer)) { - throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . $wordCountVectorizer::class); - } - $tfidfTransformer = $transformers[1]; - if (!($tfidfTransformer instanceof TfIdfTransformer)) { - throw new ServiceException("Failed to load persisted transformer: Expected " . TfIdfTransformer::class . ", got" . $tfidfTransformer::class); - } - - $subjectExtractor = new SubjectExtractor(); - $subjectExtractor->setWordCountVectorizer($wordCountVectorizer); - $subjectExtractor->setTfidf($tfidfTransformer); - $extractor = new NewCompositeExtractor( - $this->vanillaExtractor, - $subjectExtractor, - ); - } else { - $logger->warning('Falling back to vanilla feature extractor'); - $extractor = $this->vanillaExtractor; - } + [$estimator, $extractor] = $pipeline; + $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']); $features = $this->getFeaturesAndImportance( $account, $this->getIncomingMailboxes($account), $this->getOutgoingMailboxes($account), $messagesWithSender, - $extractor + $extractor, ); - $predictions = $pipeline->getEstimator()->predict( + $predictions = $estimator->predict( Unlabeled::build(array_column($features, 'features')) ); return array_combine( diff --git a/lib/Service/Classification/NewMessagesClassifier.php b/lib/Service/Classification/NewMessagesClassifier.php index 3b2fe4f6bb..ad9029c03a 100644 --- a/lib/Service/Classification/NewMessagesClassifier.php +++ b/lib/Service/Classification/NewMessagesClassifier.php @@ -52,7 +52,8 @@ public function __construct( private TagMapper $tagMapper, private LoggerInterface $logger, private IMailManager $mailManager, - private IUserPreferences $preferences) { + private IUserPreferences $preferences, + ) { } /** diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index 3006b3c93b..e779893e51 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -33,6 +33,10 @@ use OCA\Mail\Db\ClassifierMapper; use OCA\Mail\Exception\ServiceException; use OCA\Mail\Model\ClassifierPipeline; +use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; use OCP\App\IAppManager; use OCP\AppFramework\Db\DoesNotExistException; use OCP\AppFramework\Utility\ITimeFactory; @@ -42,14 +46,17 @@ use OCP\Files\NotPermittedException; use OCP\ICacheFactory; use OCP\ITempManager; +use Psr\Container\ContainerExceptionInterface; +use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; -use Rubix\ML\Estimator; use Rubix\ML\Learner; use Rubix\ML\Persistable; use Rubix\ML\PersistentModel; use Rubix\ML\Persisters\Filesystem; use Rubix\ML\Serializers\RBX; +use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\Transformer; +use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; use function file_get_contents; use function file_put_contents; @@ -80,13 +87,16 @@ class PersistenceService { /** @var LoggerInterface */ private $logger; + private ContainerInterface $container; + public function __construct(ClassifierMapper $mapper, IAppData $appData, ITempManager $tempManager, ITimeFactory $timeFactory, IAppManager $appManager, ICacheFactory $cacheFactory, - LoggerInterface $logger) { + LoggerInterface $logger, + ContainerInterface $container) { $this->mapper = $mapper; $this->appData = $appData; $this->tempManager = $tempManager; @@ -94,6 +104,7 @@ public function __construct(ClassifierMapper $mapper, $this->appManager = $appManager; $this->cacheFactory = $cacheFactory; $this->logger = $logger; + $this->container = $container; } /** @@ -189,8 +200,6 @@ public function persist(Classifier $classifier, $transformerIndex++; } - $classifier->setTransformerCount($transformerIndex); - /* * Now we set the model active so it can be used by the next request */ @@ -201,17 +210,29 @@ public function persist(Classifier $classifier, /** * @param Account $account * - * @return ?ClassifierPipeline + * @return ?array [Estimator, IExtractor] * * @throws ServiceException */ - public function loadLatest(Account $account): ?ClassifierPipeline { + public function loadLatest(Account $account): ?array { try { $latestModel = $this->mapper->findLatest($account->getId()); } catch (DoesNotExistException $e) { return null; } - return $this->load($latestModel); + + $pipeline = $this->load($latestModel); + try { + $extractor = $this->loadExtractor($latestModel, $pipeline); + } catch (ContainerExceptionInterface $e) { + throw new ServiceException( + "Failed to load extractor: {$e->getMessage()}", + 0, + $e, + ); + } + + return [$pipeline->getEstimator(), $extractor]; } /** @@ -223,8 +244,14 @@ public function loadLatest(Account $account): ?ClassifierPipeline { * @throws ServiceException */ public function load(Classifier $classifier): ClassifierPipeline { + $transformerCount = 0; + $appVersion = $this->parseAppVersion($classifier->getAppVersion()); + if ($appVersion[0] >= 3 && $appVersion[1] >= 2) { + $transformerCount = 2; + } + $id = $classifier->getId(); - $cached = $this->getCached($classifier->getId(), $classifier->getTransformerCount()); + $cached = $this->getCached($classifier->getId(), $transformerCount); if ($cached !== null) { $this->logger->debug("Using cached serialized classifier $id"); $serialized = $cached[0]; @@ -249,7 +276,7 @@ public function load(Classifier $classifier): ClassifierPipeline { $this->logger->debug("Serialized classifier loaded (size=$size)"); $serializedTransformers = []; - for ($i = 0; $i < $classifier->getTransformerCount(); $i++) { + for ($i = 0; $i < $transformerCount; $i++) { try { $transformerFile = $modelsFolder->getFile("{$id}_t$i"); } catch (NotFoundException $e) { @@ -300,6 +327,61 @@ public function load(Classifier $classifier): ClassifierPipeline { return new ClassifierPipeline($estimator, $transformers); } + /** + * Load and instantiate extractor based on a classifier's app version. + * + * @param Classifier $classifier + * @param ClassifierPipeline $pipeline + * @return IExtractor + * + * @throws ContainerExceptionInterface + * @throws ServiceException + */ + private function loadExtractor(Classifier $classifier, + ClassifierPipeline $pipeline): IExtractor { + $appVersion = $this->parseAppVersion($classifier->getAppVersion()); + if ($appVersion[0] >= 3 && $appVersion[1] >= 2) { + return $this->loadExtractorV2($pipeline->getTransformers()); + } + + return $this->loadExtractorV1($pipeline->getTransformers()); + } + + /** + * @return VanillaCompositeExtractor + * + * @throws ContainerExceptionInterface + */ + private function loadExtractorV1(): VanillaCompositeExtractor { + return $this->container->get(VanillaCompositeExtractor::class); + } + + /** + * @param Transformer[] $transformers + * @return NewCompositeExtractor + * + * @throws ContainerExceptionInterface + * @throws ServiceException + */ + private function loadExtractorV2(array $transformers): NewCompositeExtractor { + $wordCountVectorizer = $transformers[0]; + if (!($wordCountVectorizer instanceof WordCountVectorizer)) { + throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . $wordCountVectorizer::class); + } + $tfidfTransformer = $transformers[1]; + if (!($tfidfTransformer instanceof TfIdfTransformer)) { + throw new ServiceException("Failed to load persisted transformer: Expected " . TfIdfTransformer::class . ", got" . $tfidfTransformer::class); + } + + $subjectExtractor = new SubjectExtractor(); + $subjectExtractor->setWordCountVectorizer($wordCountVectorizer); + $subjectExtractor->setTfidf($tfidfTransformer); + return new NewCompositeExtractor( + $this->container->get(VanillaCompositeExtractor::class), + $subjectExtractor, + ); + } + private function getCacheKey(int $id): string { return "mail_classifier_$id"; } @@ -315,6 +397,9 @@ private function getTransformerCacheKey(int $id, int $index): string { * @return (?string)[]|null Array of serialized classifier and transformers */ private function getCached(int $id, int $transformerCount): ?array { + // FIXME: Will always return null as the cached, serialized data is always an empty string. + // See my note in self::cache() for further elaboration. + if (!$this->cacheFactory->isLocalCacheAvailable()) { return null; } @@ -335,6 +420,14 @@ private function getCached(int $id, int $transformerCount): ?array { } private function cache(int $id, string $serialized, array $serializedTransformers): void { + // FIXME: This is broken as some cache implementations will run the provided value through + // json_encode which drops non-utf8 strings. The serialized string contains binary + // data so an empty string will be saved instead (tested on Redis). + // Note: JSON requires strings to be valid utf8 (as per its spec). + + // IDEA: Implement a method ICache::setRaw() that forwards a raw/binary string as is to the + // underlying cache backend. + if (!$this->cacheFactory->isLocalCacheAvailable()) { return; } @@ -347,4 +440,18 @@ private function cache(int $id, string $serialized, array $serializedTransformer $transformerIndex++; } } + + /** + * Parse minor and major part of the given semver string. + * + * @return int[] + */ + private function parseAppVersion(string $version): array { + $parts = explode('.', $version); + if (count($parts) < 2) { + return [0, 0]; + } + + return [(int)$parts[0], (int)$parts[1]]; + } } diff --git a/lib/Service/Sync/ImapToDbSynchronizer.php b/lib/Service/Sync/ImapToDbSynchronizer.php index e05248a17a..f03857b26a 100644 --- a/lib/Service/Sync/ImapToDbSynchronizer.php +++ b/lib/Service/Sync/ImapToDbSynchronizer.php @@ -421,9 +421,6 @@ private function runPartialSync(Account $account, return $imapMessage->toDbMessage($mailbox->getId(), $account->getMailAccount()); }, $chunk); - // Ensure that the preview text is generated - //$dbMessages = $this->previewEnhancer->process($account, $mailbox, $dbMessages); - if ($importantTag) { $this->newMessagesClassifier->classifyNewMessages( $dbMessages,