Skip to content

Commit

Permalink
Refactor peristence
Browse files Browse the repository at this point in the history
  • Loading branch information
st3iny committed May 17, 2023
1 parent d91fe43 commit 909d31f
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 109 deletions.
11 changes: 0 additions & 11 deletions lib/Db/Classifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@
* @method int getCreatedAt()
* @method void setCreatedAt(int $createdAt)
*
* @method int getTransformerCount()
* @method void setTransformerCount(int $transformerCount)
*
* @method string|null getTransformers()
* @method void setTransformers(string|null $transformers)
*/
Expand Down Expand Up @@ -109,12 +106,6 @@ class Classifier extends Entity {
/** @var int */
protected $createdAt;

/** @var int */
protected $transformerCount;

/** @var string */
protected $transformers;

public function __construct() {
$this->addType('accountId', 'int');
$this->addType('type', 'string');
Expand All @@ -127,7 +118,5 @@ public function __construct() {
$this->addType('duration', 'int');
$this->addType('active', 'boolean');
$this->addType('createdAt', 'int');
$this->addType('transformerCount', 'int');
$this->addType('transformers', 'string');
}
}
56 changes: 0 additions & 56 deletions lib/Migration/Version3100Date20230324113141.php

This file was deleted.

34 changes: 5 additions & 29 deletions lib/Service/Classification/ImportanceClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ public function classifyImportance(Account $account,
try {
$pipeline = $this->persistenceService->loadLatest($account);
} catch (ServiceException $e) {
$logger->warning('Failed to load importance classifier: ' . $e->getMessage(), [
$logger->warning('Failed to load persisted estimator and extractor: ' . $e->getMessage(), [
'exception' => $e,
]);
}
Expand All @@ -482,41 +482,17 @@ public function classifyImportance(Account $account,
}, $messages)
);
}
$messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']);

// Load persisted transformers of the subject extractor.
// Is a bit hacky but a full abstraction would be overkill.
$transformers = $pipeline->getTransformers();
if (count($transformers) === 2) {
$wordCountVectorizer = $transformers[0];
if (!($wordCountVectorizer instanceof WordCountVectorizer)) {
throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . $wordCountVectorizer::class);
}
$tfidfTransformer = $transformers[1];
if (!($tfidfTransformer instanceof TfIdfTransformer)) {
throw new ServiceException("Failed to load persisted transformer: Expected " . TfIdfTransformer::class . ", got" . $tfidfTransformer::class);
}

$subjectExtractor = new SubjectExtractor();
$subjectExtractor->setWordCountVectorizer($wordCountVectorizer);
$subjectExtractor->setTfidf($tfidfTransformer);
$extractor = new NewCompositeExtractor(
$this->vanillaExtractor,
$subjectExtractor,
);
} else {
$logger->warning('Falling back to vanilla feature extractor');
$extractor = $this->vanillaExtractor;
}

[$estimator, $extractor] = $pipeline;
$messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']);
$features = $this->getFeaturesAndImportance(
$account,
$this->getIncomingMailboxes($account),
$this->getOutgoingMailboxes($account),
$messagesWithSender,
$extractor
$extractor,
);
$predictions = $pipeline->getEstimator()->predict(
$predictions = $estimator->predict(
Unlabeled::build(array_column($features, 'features'))
);
return array_combine(
Expand Down
3 changes: 2 additions & 1 deletion lib/Service/Classification/NewMessagesClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ public function __construct(
private TagMapper $tagMapper,
private LoggerInterface $logger,
private IMailManager $mailManager,
private IUserPreferences $preferences) {
private IUserPreferences $preferences,
) {
}

/**
Expand Down
125 changes: 116 additions & 9 deletions lib/Service/Classification/PersistenceService.php
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
use OCA\Mail\Db\ClassifierMapper;
use OCA\Mail\Exception\ServiceException;
use OCA\Mail\Model\ClassifierPipeline;
use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor;
use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor;
use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor;
use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor;
use OCP\App\IAppManager;
use OCP\AppFramework\Db\DoesNotExistException;
use OCP\AppFramework\Utility\ITimeFactory;
Expand All @@ -42,14 +46,17 @@
use OCP\Files\NotPermittedException;
use OCP\ICacheFactory;
use OCP\ITempManager;
use Psr\Container\ContainerExceptionInterface;
use Psr\Container\ContainerInterface;
use Psr\Log\LoggerInterface;
use Rubix\ML\Estimator;
use Rubix\ML\Learner;
use Rubix\ML\Persistable;
use Rubix\ML\PersistentModel;
use Rubix\ML\Persisters\Filesystem;
use Rubix\ML\Serializers\RBX;
use Rubix\ML\Transformers\TfIdfTransformer;
use Rubix\ML\Transformers\Transformer;
use Rubix\ML\Transformers\WordCountVectorizer;
use RuntimeException;
use function file_get_contents;
use function file_put_contents;
Expand Down Expand Up @@ -80,20 +87,24 @@ class PersistenceService {
/** @var LoggerInterface */
private $logger;

private ContainerInterface $container;

public function __construct(ClassifierMapper $mapper,
IAppData $appData,
ITempManager $tempManager,
ITimeFactory $timeFactory,
IAppManager $appManager,
ICacheFactory $cacheFactory,
LoggerInterface $logger) {
LoggerInterface $logger,
ContainerInterface $container) {
$this->mapper = $mapper;
$this->appData = $appData;
$this->tempManager = $tempManager;
$this->timeFactory = $timeFactory;
$this->appManager = $appManager;
$this->cacheFactory = $cacheFactory;
$this->logger = $logger;
$this->container = $container;
}

/**
Expand Down Expand Up @@ -189,8 +200,6 @@ public function persist(Classifier $classifier,
$transformerIndex++;
}

$classifier->setTransformerCount($transformerIndex);

/*
* Now we set the model active so it can be used by the next request
*/
Expand All @@ -201,17 +210,29 @@ public function persist(Classifier $classifier,
/**
* @param Account $account
*
* @return ?ClassifierPipeline
* @return ?array [Estimator, IExtractor]
*
* @throws ServiceException
*/
public function loadLatest(Account $account): ?ClassifierPipeline {
public function loadLatest(Account $account): ?array {
try {
$latestModel = $this->mapper->findLatest($account->getId());
} catch (DoesNotExistException $e) {
return null;
}
return $this->load($latestModel);

$pipeline = $this->load($latestModel);
try {
$extractor = $this->loadExtractor($latestModel, $pipeline);
} catch (ContainerExceptionInterface $e) {
throw new ServiceException(
"Failed to load extractor: {$e->getMessage()}",
0,
$e,
);
}

return [$pipeline->getEstimator(), $extractor];
}

/**
Expand All @@ -223,8 +244,14 @@ public function loadLatest(Account $account): ?ClassifierPipeline {
* @throws ServiceException
*/
public function load(Classifier $classifier): ClassifierPipeline {
$transformerCount = 0;
$appVersion = $this->parseAppVersion($classifier->getAppVersion());
if ($appVersion[0] >= 3 && $appVersion[1] >= 2) {
$transformerCount = 2;
}

$id = $classifier->getId();
$cached = $this->getCached($classifier->getId(), $classifier->getTransformerCount());
$cached = $this->getCached($classifier->getId(), $transformerCount);
if ($cached !== null) {
$this->logger->debug("Using cached serialized classifier $id");
$serialized = $cached[0];
Expand All @@ -249,7 +276,7 @@ public function load(Classifier $classifier): ClassifierPipeline {
$this->logger->debug("Serialized classifier loaded (size=$size)");

$serializedTransformers = [];
for ($i = 0; $i < $classifier->getTransformerCount(); $i++) {
for ($i = 0; $i < $transformerCount; $i++) {
try {
$transformerFile = $modelsFolder->getFile("{$id}_t$i");
} catch (NotFoundException $e) {
Expand Down Expand Up @@ -300,6 +327,61 @@ public function load(Classifier $classifier): ClassifierPipeline {
return new ClassifierPipeline($estimator, $transformers);
}

/**
* Load and instantiate extractor based on a classifier's app version.
*
* @param Classifier $classifier
* @param ClassifierPipeline $pipeline
* @return IExtractor
*
* @throws ContainerExceptionInterface
* @throws ServiceException
*/
private function loadExtractor(Classifier $classifier,
ClassifierPipeline $pipeline): IExtractor {
$appVersion = $this->parseAppVersion($classifier->getAppVersion());
if ($appVersion[0] >= 3 && $appVersion[1] >= 2) {
return $this->loadExtractorV2($pipeline->getTransformers());
}

return $this->loadExtractorV1($pipeline->getTransformers());
}

/**
* @return VanillaCompositeExtractor
*
* @throws ContainerExceptionInterface
*/
private function loadExtractorV1(): VanillaCompositeExtractor {
return $this->container->get(VanillaCompositeExtractor::class);
}

/**
* @param Transformer[] $transformers
* @return NewCompositeExtractor
*
* @throws ContainerExceptionInterface
* @throws ServiceException
*/
private function loadExtractorV2(array $transformers): NewCompositeExtractor {
$wordCountVectorizer = $transformers[0];
if (!($wordCountVectorizer instanceof WordCountVectorizer)) {
throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . $wordCountVectorizer::class);
}
$tfidfTransformer = $transformers[1];
if (!($tfidfTransformer instanceof TfIdfTransformer)) {
throw new ServiceException("Failed to load persisted transformer: Expected " . TfIdfTransformer::class . ", got" . $tfidfTransformer::class);
}

$subjectExtractor = new SubjectExtractor();
$subjectExtractor->setWordCountVectorizer($wordCountVectorizer);
$subjectExtractor->setTfidf($tfidfTransformer);
return new NewCompositeExtractor(
$this->container->get(VanillaCompositeExtractor::class),
$subjectExtractor,
);
}

private function getCacheKey(int $id): string {
return "mail_classifier_$id";
}
Expand All @@ -315,6 +397,9 @@ private function getTransformerCacheKey(int $id, int $index): string {
* @return (?string)[]|null Array of serialized classifier and transformers
*/
private function getCached(int $id, int $transformerCount): ?array {
// FIXME: Will always return null as the cached, serialized data is always an empty string.
// See my note in self::cache() for further elaboration.

if (!$this->cacheFactory->isLocalCacheAvailable()) {
return null;
}
Expand All @@ -335,6 +420,14 @@ private function getCached(int $id, int $transformerCount): ?array {
}

private function cache(int $id, string $serialized, array $serializedTransformers): void {
// FIXME: This is broken as some cache implementations will run the provided value through
// json_encode which drops non-utf8 strings. The serialized string contains binary
// data so an empty string will be saved instead (tested on Redis).
// Note: JSON requires strings to be valid utf8 (as per its spec).

// IDEA: Implement a method ICache::setRaw() that forwards a raw/binary string as is to the
// underlying cache backend.

if (!$this->cacheFactory->isLocalCacheAvailable()) {
return;
}
Expand All @@ -347,4 +440,18 @@ private function cache(int $id, string $serialized, array $serializedTransformer
$transformerIndex++;
}
}

/**
* Parse minor and major part of the given semver string.
*
* @return int[]
*/
private function parseAppVersion(string $version): array {
$parts = explode('.', $version);
if (count($parts) < 2) {
return [0, 0];
}

return [(int)$parts[0], (int)$parts[1]];
}
}
3 changes: 0 additions & 3 deletions lib/Service/Sync/ImapToDbSynchronizer.php
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,6 @@ private function runPartialSync(Account $account,
return $imapMessage->toDbMessage($mailbox->getId(), $account->getMailAccount());
}, $chunk);

// Ensure that the preview text is generated
//$dbMessages = $this->previewEnhancer->process($account, $mailbox, $dbMessages);

if ($importantTag) {
$this->newMessagesClassifier->classifyNewMessages(
$dbMessages,
Expand Down

0 comments on commit 909d31f

Please sign in to comment.