@@ -5,20 +5,21 @@
 import json
 import os
 import re
+import sqlite3
 from abc import ABC
 from collections import deque
 from html import unescape as unescape_html
 from shutil import rmtree
 from typing import Any, Callable, Iterable, Optional
 from xml.sax.saxutils import unescape as unescape_xml
-import sqlite3
 
 import dill as pickle
 import lz4.frame
+import msgspec
 import numpy as np
 import spacy
 import torch
-from recordclass import dataobject
+from msgspec import field
 from scipy.sparse import csr_matrix
 from sentence_transformers import SentenceTransformer, util
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -33,8 +34,7 @@
 PHILO_TEXT_OBJECT_LEVELS = {"doc": 1, "div1": 2, "div2": 3, "div3": 4, "para": 5, "sent": 6, "word": 7}
 TEMP_DIR = os.getcwd()
 
-
-class PassageGroup(dataobject, fast_new=True):
+class PassageGroup(msgspec.Struct, array_like=True):
     """Text passage with all associated properties and vector representation"""
 
     start_byte: int = 0
@@ -43,13 +43,15 @@ class PassageGroup(dataobject, fast_new=True):
     metadata: dict = {}
 
 
-class MergedGroup(dataobject, fast_new=True):
+class MergedGroup(msgspec.Struct, array_like=True):
     """A source and target PassageGroup pair with similarity"""
 
-    source: PassageGroup = PassageGroup()
-    target: PassageGroup = PassageGroup()
+    source: PassageGroup = field(default_factory=PassageGroup)
+    target: PassageGroup = field(default_factory=PassageGroup)
     similarity: float = 0.0
 
+ENCODER = msgspec.msgpack.Encoder()
+DECODER = msgspec.msgpack.Decoder(type=MergedGroup)
 
 class DocumentChunks:
     """A generator with caching"""
@@ -163,17 +165,21 @@ def __init__(self, matches: Iterable[MergedGroup]):
         self.is_cached = True
         self.count = self.__save(matches)  # save generator to disk
 
-    def extend(self, new_matches: Iterable[MergedGroup]):
-        """Add new matches to existing matches"""
+    def match_generator(self, new_matches):
         for match in new_matches:
-            dump = pickle.dumps(match)
-            self.cursor.execute("INSERT INTO matches VALUES (?, ?)", (self.count, dump))
+            dump = ENCODER.encode(match)
+            yield (self.count, dump)
             self.count += 1
 
+    def extend(self, new_matches: Iterable[MergedGroup]):
+        """Add new matches to existing matches"""
+        encoded_matches = self.match_generator(new_matches)
+        self.cursor.executemany("INSERT INTO matches VALUES (?, ?)", encoded_matches)
+
     def __save(self, matches):
         count = 0
         for count, match in enumerate(matches):
-            dump = pickle.dumps(match)
+            dump = ENCODER.encode(match)
             self.cursor.execute("INSERT INTO matches VALUES (?, ?)", (self.count, dump))
             if count == 0:
                 return 0
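
The rewritten extend() replaces one execute() per match with a single executemany() fed by a generator, so matches are encoded and inserted lazily rather than materialized as a list first. A rough, self-contained sketch of that pattern (table and column names inferred from the SELECT/INSERT statements in this file; everything else illustrative):

    import sqlite3
    import msgspec

    encoder = msgspec.msgpack.Encoder()
    conn = sqlite3.connect(":memory:")
    cursor = conn.cursor()
    cursor.execute("CREATE TABLE matches (match_id INTEGER, match BLOB)")

    def encoded_rows(matches):
        # Encode lazily: executemany pulls one (id, blob) row at a time from the generator.
        for match_id, match in enumerate(matches):
            yield (match_id, encoder.encode(match))

    cursor.executemany("INSERT INTO matches VALUES (?, ?)", encoded_rows([{"similarity": 0.9}] * 3))
    conn.commit()
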
@@ -193,7 +199,7 @@ def load(cls):
         cursor = conn.cursor()
         cursor.execute("SELECT match from matches ORDER BY match_id")
         for match in cursor:
-            matches.append(pickle.loads(match[0]))
+            matches.append(DECODER.decode(match[0]))
         conn.close()
         return cls(matches)
 
@@ -207,7 +213,7 @@ def __iter__(self):
         else:
             self.cursor.execute("SELECT match FROM matches ORDER BY match_id")
             for match in self.cursor:
-                yield pickle.loads(match[0])
+                yield DECODER.decode(match[0])
 
 
 class Corpus(ABC):
@@ -390,7 +396,7 @@ def process_inner_compare(self, results, min_similarity: float, outer_start_inde
                 self.metadata[inner_doc_id]["filename"],
                 self.metadata[inner_doc_id],
             ),
-            results[outer_doc_id, inner_doc_id],  # type: ignore
+            float(results[outer_doc_id, inner_doc_id]),  # type: ignore
         )
 
     def process_outer_compare(
@@ -413,7 +419,7 @@ def process_outer_compare(
                 target_corpus.metadata[inner_index]["filename"],
                 target_corpus.metadata[inner_index],
             ),
-            results[outer_doc_id, inner_doc_id],  # type: ignore
+            float(results[outer_doc_id, inner_doc_id]),  # type: ignore
         )
 
 
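
The float() wrapping in these two hunks is presumably needed because values indexed out of the similarity results matrix are numpy scalars, which (unlike dill pickling) the default msgpack encoder does not accept, so they are cast to plain Python floats before being stored on a MergedGroup. A small illustration of the type difference (dtype assumed; illustrative only):

    import numpy as np

    results = np.array([[0.25, 0.75]], dtype=np.float32)
    similarity = results[0, 1]
    assert isinstance(similarity, np.floating)      # numpy scalar, not a builtin float
    assert isinstance(float(similarity), float)     # cast yields a plain Python float
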
@@ -903,7 +909,7 @@ def run_vsa(source_path: str, target_path: str, workers: int, config: dict[str,
     config["source"]["strip_punctuation"] = False
     config["target"]["strip_punctuation"] = False
     source_preproc = PreProcessor(is_philo_db=True, workers=workers, **config["source"])
-    target_preproc = PreProcessor(is_philo_db=True, workers=workers, nlp_model=source_preproc.nlp, **config["target"])
+    target_preproc = PreProcessor(is_philo_db=True, workers=workers, nlp_model=source_preproc.nlp, using_gpu=source_preproc.using_gpu, **config["target"])
     source_texts: Iterable[Tokens] = source_preproc.process_texts(
         (file.path for file in os.scandir(source_path)), keep_all=True, progress=False
     )