Skip to content

Commit 3fb5de0

Browse files
committed
caching optimization for vsa
1 parent 56e213b commit 3fb5de0

File tree

3 files changed

+29
-22
lines changed

3 files changed

+29
-22
lines changed

lib/textpair/generate_ngrams.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
import configparser
55
import os
6+
import sqlite3
67
from collections import defaultdict
78
from glob import glob
89
from typing import Any, Dict, List, Tuple
9-
import sqlite3
1010

1111
import orjson
1212
from mmh3 import hash as hash32

lib/textpair/vector_space_aligner.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,21 @@
55
import json
66
import os
77
import re
8+
import sqlite3
89
from abc import ABC
910
from collections import deque
1011
from html import unescape as unescape_html
1112
from shutil import rmtree
1213
from typing import Any, Callable, Iterable, Optional
1314
from xml.sax.saxutils import unescape as unescape_xml
14-
import sqlite3
1515

1616
import dill as pickle
1717
import lz4.frame
18+
import msgspec
1819
import numpy as np
1920
import spacy
2021
import torch
21-
from recordclass import dataobject
22+
from msgspec import field
2223
from scipy.sparse import csr_matrix
2324
from sentence_transformers import SentenceTransformer, util
2425
from sklearn.feature_extraction.text import TfidfVectorizer
@@ -33,8 +34,7 @@
3334
PHILO_TEXT_OBJECT_LEVELS = {"doc": 1, "div1": 2, "div2": 3, "div3": 4, "para": 5, "sent": 6, "word": 7}
3435
TEMP_DIR = os.getcwd()
3536

36-
37-
class PassageGroup(dataobject, fast_new=True):
37+
class PassageGroup(msgspec.Struct, array_like=True):
3838
"""Text passage with all associated properties and vector representation"""
3939

4040
start_byte: int = 0
@@ -43,13 +43,15 @@ class PassageGroup(dataobject, fast_new=True):
4343
metadata: dict = {}
4444

4545

46-
class MergedGroup(dataobject, fast_new=True):
46+
class MergedGroup(msgspec.Struct, array_like=True):
4747
"""A source and target PassageGroup pair with similarity"""
4848

49-
source: PassageGroup = PassageGroup()
50-
target: PassageGroup = PassageGroup()
49+
source: PassageGroup = field(default_factory=PassageGroup)
50+
target: PassageGroup = field(default_factory=PassageGroup)
5151
similarity: float = 0.0
5252

53+
ENCODER = msgspec.msgpack.Encoder()
54+
DECODER = msgspec.msgpack.Decoder(type=MergedGroup)
5355

5456
class DocumentChunks:
5557
"""A generator with caching"""
@@ -163,17 +165,21 @@ def __init__(self, matches: Iterable[MergedGroup]):
163165
self.is_cached = True
164166
self.count = self.__save(matches) # save generator to disk
165167

166-
def extend(self, new_matches: Iterable[MergedGroup]):
167-
"""Add new matches to existing matches"""
168+
def match_generator(self, new_matches):
168169
for match in new_matches:
169-
dump = pickle.dumps(match)
170-
self.cursor.execute("INSERT INTO matches VALUES (?, ?)", (self.count, dump))
170+
dump = ENCODER.encode(match)
171+
yield (self.count, dump)
171172
self.count += 1
172173

174+
def extend(self, new_matches: Iterable[MergedGroup]):
175+
"""Add new matches to existing matches"""
176+
encoded_matches = self.match_generator(new_matches)
177+
self.cursor.executemany("INSERT INTO matches VALUES (?, ?)", encoded_matches)
178+
173179
def __save(self, matches):
174180
count = 0
175181
for count, match in enumerate(matches):
176-
dump = pickle.dumps(match)
182+
dump = ENCODER.encode(match)
177183
self.cursor.execute("INSERT INTO matches VALUES (?, ?)", (self.count, dump))
178184
if count == 0:
179185
return 0
@@ -193,7 +199,7 @@ def load(cls):
193199
cursor = conn.cursor()
194200
cursor.execute("SELECT match from matches ORDER BY match_id")
195201
for match in cursor:
196-
matches.append(pickle.loads(match[0]))
202+
matches.append(DECODER.decode(match[0]))
197203
conn.close()
198204
return cls(matches)
199205

@@ -207,7 +213,7 @@ def __iter__(self):
207213
else:
208214
self.cursor.execute("SELECT match FROM matches ORDER BY match_id")
209215
for match in self.cursor:
210-
yield pickle.loads(match[0])
216+
yield DECODER.decode(match[0])
211217

212218

213219
class Corpus(ABC):
@@ -390,7 +396,7 @@ def process_inner_compare(self, results, min_similarity: float, outer_start_inde
390396
self.metadata[inner_doc_id]["filename"],
391397
self.metadata[inner_doc_id],
392398
),
393-
results[outer_doc_id, inner_doc_id], # type: ignore
399+
float(results[outer_doc_id, inner_doc_id]), # type: ignore
394400
)
395401

396402
def process_outer_compare(
@@ -413,7 +419,7 @@ def process_outer_compare(
413419
target_corpus.metadata[inner_index]["filename"],
414420
target_corpus.metadata[inner_index],
415421
),
416-
results[outer_doc_id, inner_doc_id], # type: ignore
422+
float(results[outer_doc_id, inner_doc_id]), # type: ignore
417423
)
418424

419425

@@ -903,7 +909,7 @@ def run_vsa(source_path: str, target_path: str, workers: int, config: dict[str,
903909
config["source"]["strip_punctuation"] = False
904910
config["target"]["strip_punctuation"] = False
905911
source_preproc = PreProcessor(is_philo_db=True, workers=workers, **config["source"])
906-
target_preproc = PreProcessor(is_philo_db=True, workers=workers, nlp_model=source_preproc.nlp, **config["target"])
912+
target_preproc = PreProcessor(is_philo_db=True, workers=workers, nlp_model=source_preproc.nlp, using_gpu=source_preproc.using_gpu, **config["target"])
907913
source_texts: Iterable[Tokens] = source_preproc.process_texts(
908914
(file.path for file in os.scandir(source_path)), keep_all=True, progress=False
909915
)

web-app/src/components/searchResults.vue

+5-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
</div>
77
<report-switcher />
88
<div class="row">
9-
<div class="col position-relative">
9+
<div class="col-9 position-relative">
1010
<div class="d-flex justify-content-center position-relative" v-if="loading">
1111
<div class="spinner-border"
1212
style="width: 8rem; height: 8rem; position: absolute; z-index: 50; top: 30px" role="status">
@@ -65,7 +65,8 @@
6565
</div>
6666
</div>
6767
</div>
68-
<div class="loading position-absolute" style="left: 50%; transform: translateX(-50%)" v-if="facetLoading">
68+
<div class="loading position-absolute" style="left: 50%; transform: translateX(-50%)"
69+
v-if="facetLoading">
6970
<div class="d-flex justify-content-center position-relative">
7071
<div class="spinner-border"
7172
style="width: 4rem; height: 4rem; position: absolute; z-index: 50; top: 30px" role="status">
@@ -100,10 +101,10 @@
100101
</template>
101102

102103
<script>
103-
import searchArguments from "./searchArguments";
104+
import Velocity from "velocity-animate";
104105
import passagePair from "./passagePair";
105106
import reportSwitcher from "./reportSwitcher";
106-
import Velocity from "velocity-animate";
107+
import searchArguments from "./searchArguments";
107108
108109
export default {
109110
name: "searchResults",

0 commit comments

Comments
 (0)