Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Semantic search server communicator #717

Merged
merged 1 commit into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions orangecontrib/text/semantic_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import json
import base64
import zlib
import sys
import re
from typing import Any, Optional, List, Optional, Tuple, Callable

import numpy as np

from Orange.misc.server_embedder import ServerEmbedderCommunicator
from orangecontrib.text import Corpus
from orangecontrib.text.misc import wait_nltk_data


# Upper bound used when packing requests: a single encoded text whose size
# exceeds it is skipped, and a chunk whose summed sizes exceed it is split.
MAX_PACKAGE_SIZE = 50000
# Number of pieces the encoded texts are divided into by the top-level split.
MIN_CHUNKS = 20


class SemanticSearch:
    """Callable that matches texts against queries via the server communicator.

    Texts and queries are compressed, base64-encoded, grouped into chunks and
    sent through a ``_ServerCommunicator`` configured for the
    ``'semantic-search'`` model.
    """

    def __init__(self) -> None:
        self._server_communicator = _ServerCommunicator(
            model_name='semantic-search',
            max_parallel_requests=100,
            server_url='https://apiv2.garaza.io',
            embedder_type='text',
        )

    def __call__(
        self, texts: List[str], queries: List[str],
        progress_callback: Optional[Callable] = None
    ) -> List[Optional[List[Tuple[Tuple[int, int], float]]]]:
        """Computes matches for given documents and queries.

        Parameters
        ----------
        texts: List[str]
            A list of raw texts to be matched.
        queries: List[str]
            A list of query words/phrases.
        progress_callback: Optional[Callable]
            Forwarded to the communicator as ``processed_callback``.

        Returns
        -------
        List[Optional[List[Tuple[Tuple[int, int], float]]]]
            The elements of the outer list represent each document. The entries
            are either None or lists of matches. Entries of each list of matches
            are matches for each sentence. Each match is of the form
            ((start_idx, end_idx), score). Note that tuples are actually
            converted to lists before the result is returned.

            An entry is None when the encoded text is larger than
            MAX_PACKAGE_SIZE, when the communicator returns no result at all,
            or when the communicator returns None for the chunk that
            contained the text. If ``texts`` or ``queries`` is empty, every
            entry is None.
        """
        if len(texts) == 0 or len(queries) == 0:
            return [None] * len(texts)

        queries_enc = self._encode(json.dumps(queries))

        # Indices of texts that are too large to send; a set keeps the
        # membership test below O(1).
        skipped = set()
        encoded_texts = []
        sizes = []
        for i, text in enumerate(texts):
            encoded = self._encode(text)
            # sys.getsizeof reports the shallow in-memory size of the str
            # object, not its character count.
            size = sys.getsizeof(encoded)
            if size > MAX_PACKAGE_SIZE:
                skipped.add(i)
                continue
            encoded_texts.append(encoded)
            sizes.append(size)

        text_chunks = self._make_chunks(encoded_texts, sizes)
        # Every request carries its chunk of texts plus the (shared) queries.
        chunks = [[text_chunk, queries_enc] for text_chunk in text_chunks]

        result_ = self._server_communicator.embedd_data(
            chunks, processed_callback=progress_callback,
        )
        if result_ is None:
            return [None] * len(texts)

        result = []
        for chunk_result, text_chunk in zip(result_, text_chunks):
            if chunk_result is None:
                # NOTE(review): the communicator is expected to yield None for
                # a chunk it could not process - confirm against
                # ServerEmbedderCommunicator.embedd_data. Keep one placeholder
                # per text so the remaining results stay aligned.
                result.extend([None] * len(text_chunk))
            else:
                result.extend(chunk_result)

        # Re-insert None at the positions of the skipped texts.
        results = []
        idx = 0
        for i in range(len(texts)):
            if i in skipped:
                results.append(None)
            else:
                results.append(result[idx])
                idx += 1

        return results

    @staticmethod
    def _encode(text: str) -> str:
        """Return ``text`` zlib-compressed and base64-encoded as a str."""
        compressed = zlib.compress(text.encode('utf-8', 'replace'), level=-1)
        return base64.b64encode(compressed).decode('utf-8', 'replace')

    @staticmethod
    def _split_evenly(items, parts):
        """Split ``items`` into ``parts`` consecutive lists of near-equal length.

        Section lengths follow ``numpy.array_split``: the first
        ``len(items) % parts`` sections hold one extra element. Working on
        plain lists avoids turning the texts into a fixed-width unicode
        ndarray, which pads every element to the longest one.
        """
        items = list(items)
        base, extra = divmod(len(items), parts)
        sections = []
        start = 0
        for k in range(parts):
            stop = start + base + (1 if k < extra else 0)
            sections.append(items[start:stop])
            start = stop
        return sections

    def _make_chunks(self, encoded_texts, sizes, depth=0):
        """Group encoded texts into chunks that respect MAX_PACKAGE_SIZE.

        The top-level call (``depth == 0``) splits into MIN_CHUNKS pieces;
        any piece whose summed sizes exceed MAX_PACKAGE_SIZE is recursively
        halved. Empty pieces are dropped.

        Parameters
        ----------
        encoded_texts
            Sequence of encoded texts.
        sizes
            Sequence of sizes; split the same way as ``encoded_texts``.
        depth: int
            Recursion depth; callers normally leave it at 0.

        Returns
        -------
        List[list]
            Non-empty chunks, in the original order of the texts.
        """
        n_parts = MIN_CHUNKS if depth == 0 else 2
        chunks = self._split_evenly(encoded_texts, n_parts)
        chunk_sizes = self._split_evenly(sizes, n_parts)
        result = []
        for chunk, chunk_size in zip(chunks, chunk_sizes):
            # A chunk with a single text cannot be split any further; without
            # the length check an oversized text would recurse forever.
            if len(chunk) > 1 and sum(chunk_size) > MAX_PACKAGE_SIZE:
                result.extend(self._make_chunks(chunk, chunk_size, depth + 1))
            else:
                result.append(chunk)
        return [list(r) for r in result if len(r) > 0]

    def set_cancelled(self):
        """Forward cancellation to the communicator, if one was created."""
        if hasattr(self, '_server_communicator'):
            self._server_communicator.set_cancelled()

    def clear_cache(self):
        """Clear the communicator's cache, if a communicator was created."""
        # getattr keeps this safe when __init__ did not finish, mirroring
        # the guard in set_cancelled.
        communicator = getattr(self, '_server_communicator', None)
        if communicator:
            communicator.clear_cache()

    def __enter__(self):
        return self

    def __exit__(self, ex_type, value, traceback):
        self.set_cancelled()

    def __del__(self):
        self.__exit__(None, None, None)


class _ServerCommunicator(ServerEmbedderCommunicator):
    """Server communicator that sends each data instance as a JSON body."""

    def __init__(self, *args, **kwargs) -> None:
        """Initialise the base communicator and declare a JSON content type."""
        super().__init__(*args, **kwargs)
        self.content_type = 'application/json'

    async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
        """Serialise ``data_instance`` to JSON and return it as UTF-8 bytes."""
        serialized = json.dumps(data_instance)
        return serialized.encode('utf-8', 'replace')
117 changes: 117 additions & 0 deletions orangecontrib/text/tests/test_semantic_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import unittest
from unittest.mock import patch
from collections.abc import Iterator
import asyncio
import numpy as np

from orangecontrib.text.semantic_search import (
SemanticSearch,
MIN_CHUNKS,
MAX_PACKAGE_SIZE
)
from orangecontrib.text import Corpus

# Dotted path of the HTTP call that the tests replace with a stand-in.
PATCH_METHOD = 'httpx.AsyncClient.post'
# Queries passed to SemanticSearch in the tests below.
QUERIES = ['test query', 'another test query']
# Raw response bodies handed out one per request by the dummy POST in
# test_success; each is a JSON object with an "embedding" key.
RESPONSE = [
b'{ "embedding": [[[[0, 57], 0.22114424407482147]]] }',
b'{ "embedding": [[[[0, 57], 0.5597518086433411]]] }',
b'{ "embedding": [[[[0, 40], 0.11774948984384537]]] }',
b'{ "embedding": [[[[0, 50], 0.2228381633758545]]] }',
b'{ "embedding": [[[[0, 61], 0.19825558364391327]]] }',
b'{ "embedding": [[[[0, 47], 0.19025272130966187]]] }',
b'{ "embedding": [[[[0, 40], 0.09688498824834824]]] }',
b'{ "embedding": [[[[0, 55], 0.2982504367828369]]] }',
b'{ "embedding": [[[[0, 12], 0.2982504367828369]]] }',
]
# Result test_success expects from SemanticSearch: one entry per RESPONSE
# item, holding the same ranges and scores with one level of nesting removed.
IDEAL_RESPONSE = [
[[[0, 57], 0.22114424407482147]],
[[[0, 57], 0.5597518086433411]],
[[[0, 40], 0.11774948984384537]],
[[[0, 50], 0.2228381633758545]],
[[[0, 61], 0.19825558364391327]],
[[[0, 47], 0.19025272130966187]],
[[[0, 40], 0.09688498824834824]],
[[[0, 55], 0.2982504367828369]],
[[[0, 12], 0.2982504367828369]]
]


class DummyResponse:
    """Minimal response stand-in that only exposes a ``content`` attribute."""

    def __init__(self, content: bytes) -> None:
        # The only attribute this stand-in provides is the raw body.
        self.content = content


def make_dummy_post(response, sleep=0):
    """Build an async replacement for the patched POST call.

    Parameters
    ----------
    response
        Either an iterator, from which each call takes the next item as the
        response body, or a single value returned as the body of every call.
    sleep
        Seconds to await before each response is produced.

    Returns
    -------
    staticmethod
        Wraps a coroutine function taking ``(url, headers, data)`` and
        resolving to a ``DummyResponse``.
    """
    # `response` is never rebound, so the kind of source is fixed up front.
    draws_from_iterator = isinstance(response, Iterator)

    @staticmethod
    async def dummy_post(url, headers, data):
        await asyncio.sleep(sleep)
        if draws_from_iterator:
            body = next(response)
        else:
            body = response
        return DummyResponse(content=body)

    return dummy_post


class SemanticSearchTest(unittest.TestCase):
    """Tests for SemanticSearch chunking and its handling of responses."""

    def setUp(self):
        self.semantic_search = SemanticSearch()
        self.corpus = Corpus.from_file('deerwester')

    def tearDown(self):
        # Drop cached results so one test cannot satisfy another from cache.
        self.semantic_search.clear_cache()

    def test_make_chunks_small(self):
        """Small sizes: one chunk per document, capped at MIN_CHUNKS."""
        chunks = self.semantic_search._make_chunks(
            self.corpus.documents, [100] * len(self.corpus.documents)
        )
        self.assertEqual(len(chunks), min(len(self.corpus.documents), MIN_CHUNKS))

    def test_make_chunks_medium(self):
        """Sizes just under the per-chunk budget need no further splitting."""
        num_docs = len(self.corpus.documents)
        documents = self.corpus.documents
        if num_docs < MIN_CHUNKS:
            documents = [documents[0]] * MIN_CHUNKS
        chunks = self.semantic_search._make_chunks(
            documents, [MAX_PACKAGE_SIZE / MIN_CHUNKS - 1] * len(documents)
        )
        self.assertEqual(len(chunks), MIN_CHUNKS)

    def test_make_chunks_large(self):
        """Chunks exceeding MAX_PACKAGE_SIZE are split beyond MIN_CHUNKS."""
        num_docs = len(self.corpus.documents)
        documents = self.corpus.documents
        if num_docs < MIN_CHUNKS:
            documents = [documents[0]] * MIN_CHUNKS * 100
        mps = MAX_PACKAGE_SIZE
        # Three explicit sizes are appended, so the filler must cover
        # len(documents) - 3 entries to give exactly one size per document.
        chunks = self.semantic_search._make_chunks(
            documents,
            [mps / 100] * (len(documents) - 3) + [0.3 * mps, 0.9 * mps, mps]
        )
        self.assertGreater(len(chunks), MIN_CHUNKS)

    @patch(PATCH_METHOD)
    def test_empty_corpus(self, mock):
        """An empty corpus yields an empty result without any POST call."""
        self.assertEqual(
            len(self.semantic_search(self.corpus.documents[:0], QUERIES)), 0
        )
        # `mock` is the patched POST itself. Asserting on child attributes
        # such as `mock.request` would always pass, because they are
        # auto-created mocks that nothing ever calls.
        mock.assert_not_called()
        self.assertEqual(
            self.semantic_search._server_communicator._cache._cache_dict,
            dict()
        )

    @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
    def test_success(self):
        """Each document receives the matches from its own response."""
        result = self.semantic_search(self.corpus.documents, QUERIES)
        self.assertEqual(result, IDEAL_RESPONSE)

    @patch(PATCH_METHOD, make_dummy_post(RESPONSE[0]))
    def test_success_chunks(self):
        """With at least MIN_CHUNKS documents, one result per document."""
        num_docs = len(self.corpus.documents)
        documents = self.corpus.documents
        if num_docs < MIN_CHUNKS:
            documents = [documents[0]] * MIN_CHUNKS
        result = self.semantic_search(documents, QUERIES)
        self.assertEqual(len(result), MIN_CHUNKS)