Merge pull request #717 from djukicn/semantic-search
Semantic search server communicator
Showing 2 changed files with 257 additions and 0 deletions.
@@ -0,0 +1,140 @@
import json
import base64
import zlib
import sys
import re
from typing import Any, Optional, List, Tuple, Callable

import numpy as np

from Orange.misc.server_embedder import ServerEmbedderCommunicator
from orangecontrib.text import Corpus
from orangecontrib.text.misc import wait_nltk_data


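# A single encoded text larger than MAX_PACKAGE_SIZE bytes is skipped;
# batches are first split into MIN_CHUNKS chunks, then recursively bisected
# until each chunk's total size fits under MAX_PACKAGE_SIZE.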
MAX_PACKAGE_SIZE = 50000
MIN_CHUNKS = 20


class SemanticSearch:

    def __init__(self) -> None:
        self._server_communicator = _ServerCommunicator(
            model_name='semantic-search',
            max_parallel_requests=100,
            server_url='https://apiv2.garaza.io',
            embedder_type='text',
        )

    def __call__(
        self, texts: List[str], queries: List[str],
        progress_callback: Optional[Callable] = None
    ) -> List[Optional[List[Tuple[Tuple[int, int], float]]]]:
        """Compute matches for the given documents and queries.

        Parameters
        ----------
        texts: List[str]
            A list of raw texts to be matched.
        queries: List[str]
            A list of query words/phrases.

        Returns
        -------
        List[Optional[List[Tuple[Tuple[int, int], float]]]]
            The elements of the outer list represent each document. The
            entries are either None or lists of matches. Entries of each
            list of matches are matches for each sentence. Each match is
            of the form ((start_idx, end_idx), score). Note that tuples
            are actually converted to lists before the result is returned.
        """

        if len(texts) == 0 or len(queries) == 0:
            return [None] * len(texts)

        skipped = list()
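        # Queries are JSON-serialized, zlib-compressed and base64-encoded so
        # they can ride along inside each chunk's JSON payload.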
        queries_enc = base64.b64encode(
            zlib.compress(
                json.dumps(queries).encode('utf-8', 'replace'),
                level=-1
            )
        ).decode('utf-8', 'replace')

        encoded_texts = list()
        sizes = list()
        chunks = list()
        for i, text in enumerate(texts):
            encoded = base64.b64encode(zlib.compress(
                text.encode('utf-8', 'replace'), level=-1)
            ).decode('utf-8', 'replace')
            size = sys.getsizeof(encoded)
            if size > MAX_PACKAGE_SIZE:
                skipped.append(i)
                continue
            encoded_texts.append(encoded)
            sizes.append(size)

        chunks_ = self._make_chunks(encoded_texts, sizes)
        for chunk in chunks_:
            chunks.append([chunk, queries_enc])

        result_ = self._server_communicator.embedd_data(
            chunks, processed_callback=progress_callback,
        )
        if result_ is None:
            return [None] * len(texts)

        result = list()
        for chunk in result_:
            result.extend(chunk)

        results = list()
        idx = 0
        for i in range(len(texts)):
            if i in skipped:
                results.append(None)
            else:
                results.append(result[idx])
                idx += 1

        return results

    def _make_chunks(self, encoded_texts, sizes, depth=0):
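        # Split into MIN_CHUNKS chunks at the top level, then bisect any
        # chunk whose total encoded size still exceeds MAX_PACKAGE_SIZE.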
        chunks = np.array_split(encoded_texts, MIN_CHUNKS if depth == 0 else 2)
        chunk_sizes = np.array_split(sizes, MIN_CHUNKS if depth == 0 else 2)
        result = list()
        for i in range(len(chunks)):
            if np.sum(chunk_sizes[i]) > MAX_PACKAGE_SIZE:
                result.extend(
                    self._make_chunks(chunks[i], chunk_sizes[i], depth + 1)
                )
            else:
                result.append(chunks[i])
        return [list(r) for r in result if len(r) > 0]

    def set_cancelled(self):
        if hasattr(self, '_server_communicator'):
            self._server_communicator.set_cancelled()

    def clear_cache(self):
        if self._server_communicator:
            self._server_communicator.clear_cache()

    def __enter__(self):
        return self

    def __exit__(self, ex_type, value, traceback):
        self.set_cancelled()

    def __del__(self):
        self.__exit__(None, None, None)


class _ServerCommunicator(ServerEmbedderCommunicator):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.content_type = 'application/json'

    async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
        return json.dumps(data_instance).encode('utf-8', 'replace')
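
For orientation, a minimal usage sketch of the class above. It is not part of this commit; it assumes the file lands at orangecontrib.text.semantic_search (consistent with the test imports below) and that the server at https://apiv2.garaza.io is reachable. The example texts are documents from the deerwester corpus:

from orangecontrib.text.semantic_search import SemanticSearch

texts = [
    'Human machine interface for lab abc computer applications',
    'A survey of user opinion of computer system response time',
]
queries = ['user interface']

# The context-manager form guarantees set_cancelled() runs on exit,
# so pending server requests are not left hanging.
with SemanticSearch() as searcher:
    matches = searcher(texts, queries)

# Each entry is None (text skipped or no result) or a list of
# [[start_idx, end_idx], score] matches, one per sentence.
for text, match in zip(texts, matches):
    print(text[:40], match)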
@@ -0,0 +1,117 @@
import unittest
from unittest.mock import patch
from collections.abc import Iterator
import asyncio
import numpy as np

from orangecontrib.text.semantic_search import (
    SemanticSearch,
    MIN_CHUNKS,
    MAX_PACKAGE_SIZE
)
from orangecontrib.text import Corpus

PATCH_METHOD = 'httpx.AsyncClient.post'
QUERIES = ['test query', 'another test query']
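# Raw bytes as the mocked httpx post returns them; the server communicator
# parses each response and unpacks its 'embedding' field, so IDEAL_RESPONSE
# below is what the search is expected to hand back.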
RESPONSE = [
    b'{ "embedding": [[[[0, 57], 0.22114424407482147]]] }',
    b'{ "embedding": [[[[0, 57], 0.5597518086433411]]] }',
    b'{ "embedding": [[[[0, 40], 0.11774948984384537]]] }',
    b'{ "embedding": [[[[0, 50], 0.2228381633758545]]] }',
    b'{ "embedding": [[[[0, 61], 0.19825558364391327]]] }',
    b'{ "embedding": [[[[0, 47], 0.19025272130966187]]] }',
    b'{ "embedding": [[[[0, 40], 0.09688498824834824]]] }',
    b'{ "embedding": [[[[0, 55], 0.2982504367828369]]] }',
    b'{ "embedding": [[[[0, 12], 0.2982504367828369]]] }',
]
IDEAL_RESPONSE = [
    [[[0, 57], 0.22114424407482147]],
    [[[0, 57], 0.5597518086433411]],
    [[[0, 40], 0.11774948984384537]],
    [[[0, 50], 0.2228381633758545]],
    [[[0, 61], 0.19825558364391327]],
    [[[0, 47], 0.19025272130966187]],
    [[[0, 40], 0.09688498824834824]],
    [[[0, 55], 0.2982504367828369]],
    [[[0, 12], 0.2982504367828369]]
]


class DummyResponse:

    def __init__(self, content):
        self.content = content


def make_dummy_post(response, sleep=0):
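    # staticmethod prevents the patched-in coroutine from receiving the
    # AsyncClient instance as an extra first argument once it is installed
    # on the class as httpx.AsyncClient.post.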
    @staticmethod
    async def dummy_post(url, headers, data):
        await asyncio.sleep(sleep)
        return DummyResponse(
            content=next(response) if isinstance(response, Iterator) else response
        )
    return dummy_post


class SemanticSearchTest(unittest.TestCase):

    def setUp(self):
        self.semantic_search = SemanticSearch()
        self.corpus = Corpus.from_file('deerwester')

    def tearDown(self):
        self.semantic_search.clear_cache()

    def test_make_chunks_small(self):
        chunks = self.semantic_search._make_chunks(
            self.corpus.documents, [100] * len(self.corpus.documents)
        )
        self.assertEqual(len(chunks), min(len(self.corpus.documents), MIN_CHUNKS))

    def test_make_chunks_medium(self):
        num_docs = len(self.corpus.documents)
        documents = self.corpus.documents
        if num_docs < MIN_CHUNKS:
            documents = [documents[0]] * MIN_CHUNKS
        chunks = self.semantic_search._make_chunks(
            documents, [MAX_PACKAGE_SIZE / MIN_CHUNKS - 1] * len(documents)
        )
        self.assertEqual(len(chunks), MIN_CHUNKS)

    def test_make_chunks_large(self):
        num_docs = len(self.corpus.documents)
        documents = self.corpus.documents
        if num_docs < MIN_CHUNKS:
            documents = [documents[0]] * MIN_CHUNKS * 100
        mps = MAX_PACKAGE_SIZE
        chunks = self.semantic_search._make_chunks(
            documents,
            [mps / 100] * (len(documents) - 3) + [0.3 * mps, 0.9 * mps, mps]
        )
        self.assertGreater(len(chunks), MIN_CHUNKS)

    @patch(PATCH_METHOD)
    def test_empty_corpus(self, mock):
        self.assertEqual(
            len(self.semantic_search(self.corpus.documents[:0], QUERIES)), 0
        )
        mock.request.assert_not_called()
        mock.get_response.assert_not_called()
        self.assertEqual(
            self.semantic_search._server_communicator._cache._cache_dict,
            dict()
        )

    @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
    def test_success(self):
        result = self.semantic_search(self.corpus.documents, QUERIES)
        self.assertEqual(result, IDEAL_RESPONSE)

    @patch(PATCH_METHOD, make_dummy_post(RESPONSE[0]))
    def test_success_chunks(self):
        num_docs = len(self.corpus.documents)
        documents = self.corpus.documents
        if num_docs < MIN_CHUNKS:
            documents = [documents[0]] * MIN_CHUNKS
        result = self.semantic_search(documents, QUERIES)
        self.assertEqual(len(result), MIN_CHUNKS)
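
The test module as committed defines no standalone entry point and relies on the project's test runner; if one were wanted, the conventional unittest footer would be:

if __name__ == '__main__':
    unittest.main()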