Merge pull request #717 from djukicn/semantic-search
Semantic search server communicator
PrimozGodec authored Nov 3, 2021
2 parents 68941e1 + 15ed20a commit b99a1c8
Showing 2 changed files with 257 additions and 0 deletions.
140 changes: 140 additions & 0 deletions orangecontrib/text/semantic_search.py
@@ -0,0 +1,140 @@
import json
import base64
import zlib
import sys
import re
from typing import Any, Optional, List, Tuple, Callable

import numpy as np

from Orange.misc.server_embedder import ServerEmbedderCommunicator
from orangecontrib.text import Corpus
from orangecontrib.text.misc import wait_nltk_data


MAX_PACKAGE_SIZE = 50000
MIN_CHUNKS = 20


class SemanticSearch:

    def __init__(self) -> None:
        self._server_communicator = _ServerCommunicator(
            model_name='semantic-search',
            max_parallel_requests=100,
            server_url='https://apiv2.garaza.io',
            embedder_type='text',
        )

    def __call__(
        self, texts: List[str], queries: List[str],
        progress_callback: Optional[Callable] = None
    ) -> List[Optional[List[Tuple[Tuple[int, int], float]]]]:
        """Computes matches for given documents and queries.

        Parameters
        ----------
        texts: List[str]
            A list of raw texts to be matched.
        queries: List[str]
            A list of query words/phrases.

        Returns
        -------
        List[Optional[List[Tuple[Tuple[int, int], float]]]]
            The elements of the outer list represent each document. The
            entries are either None or lists of matches. Entries of each list
            of matches are matches for each sentence. Each match is of the
            form ((start_idx, end_idx), score). Note that tuples are actually
            converted to lists before the result is returned.
        """

        if len(texts) == 0 or len(queries) == 0:
            return [None] * len(texts)

        # Documents that exceed the package size limit on their own are
        # skipped and reported as None in the result.
        skipped = list()
        # Queries are shared by all requests, so they are encoded only once.
        queries_enc = base64.b64encode(
            zlib.compress(
                json.dumps(queries).encode('utf-8', 'replace'),
                level=-1
            )
        ).decode('utf-8', 'replace')

        encoded_texts = list()
        sizes = list()
        chunks = list()
        for i, text in enumerate(texts):
            encoded = base64.b64encode(zlib.compress(
                text.encode('utf-8', 'replace'), level=-1)
            ).decode('utf-8', 'replace')
            size = sys.getsizeof(encoded)
            if size > MAX_PACKAGE_SIZE:
                skipped.append(i)
                continue
            encoded_texts.append(encoded)
            sizes.append(size)

        # Each request carries one chunk of encoded documents plus the
        # encoded queries.
        chunks_ = self._make_chunks(encoded_texts, sizes)
        for chunk in chunks_:
            chunks.append([chunk, queries_enc])

        result_ = self._server_communicator.embedd_data(
            chunks, processed_callback=progress_callback,
        )
        if result_ is None:
            return [None] * len(texts)

        result = list()
        for chunk in result_:
            result.extend(chunk)

        # Re-insert None placeholders for skipped documents so the output
        # aligns with the input texts.
        results = list()
        idx = 0
        for i in range(len(texts)):
            if i in skipped:
                results.append(None)
            else:
                results.append(result[idx])
                idx += 1

        return results

    def _make_chunks(self, encoded_texts, sizes, depth=0):
        # Split into MIN_CHUNKS chunks first; any chunk whose total size still
        # exceeds MAX_PACKAGE_SIZE is recursively split in half.
        chunks = np.array_split(encoded_texts, MIN_CHUNKS if depth == 0 else 2)
        chunk_sizes = np.array_split(sizes, MIN_CHUNKS if depth == 0 else 2)
        result = list()
        for i in range(len(chunks)):
            if np.sum(chunk_sizes[i]) > MAX_PACKAGE_SIZE:
                result.extend(self._make_chunks(chunks[i], chunk_sizes[i], depth + 1))
            else:
                result.append(chunks[i])
        return [list(r) for r in result if len(r) > 0]

    def set_cancelled(self):
        if hasattr(self, '_server_communicator'):
            self._server_communicator.set_cancelled()

    def clear_cache(self):
        if self._server_communicator:
            self._server_communicator.clear_cache()

    def __enter__(self):
        return self

    def __exit__(self, ex_type, value, traceback):
        self.set_cancelled()

    def __del__(self):
        self.__exit__(None, None, None)


class _ServerCommunicator(ServerEmbedderCommunicator):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.content_type = 'application/json'

    async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
        return json.dumps(data_instance).encode('utf-8', 'replace')
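
For orientation, here is a minimal usage sketch of the class above, based on its docstring. The example texts and queries are invented, and it assumes the semantic-search server at apiv2.garaza.io is reachable:

    from orangecontrib.text.semantic_search import SemanticSearch

    texts = [
        "Human machine interface for lab abc computer applications",
        "A survey of user opinion of computer system response time",
    ]
    queries = ["computer interface"]

    searcher = SemanticSearch()
    # One entry per document: None for skipped (oversized) documents,
    # otherwise a list of [[start_idx, end_idx], score] matches.
    matches = searcher(texts, queries)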
117 changes: 117 additions & 0 deletions orangecontrib/text/tests/test_semantic_search.py
@@ -0,0 +1,117 @@
import unittest
from unittest.mock import patch
from collections.abc import Iterator
import asyncio
import numpy as np

from orangecontrib.text.semantic_search import (
    SemanticSearch,
    MIN_CHUNKS,
    MAX_PACKAGE_SIZE
)
from orangecontrib.text import Corpus

PATCH_METHOD = 'httpx.AsyncClient.post'
QUERIES = ['test query', 'another test query']
RESPONSE = [
    b'{ "embedding": [[[[0, 57], 0.22114424407482147]]] }',
    b'{ "embedding": [[[[0, 57], 0.5597518086433411]]] }',
    b'{ "embedding": [[[[0, 40], 0.11774948984384537]]] }',
    b'{ "embedding": [[[[0, 50], 0.2228381633758545]]] }',
    b'{ "embedding": [[[[0, 61], 0.19825558364391327]]] }',
    b'{ "embedding": [[[[0, 47], 0.19025272130966187]]] }',
    b'{ "embedding": [[[[0, 40], 0.09688498824834824]]] }',
    b'{ "embedding": [[[[0, 55], 0.2982504367828369]]] }',
    b'{ "embedding": [[[[0, 12], 0.2982504367828369]]] }',
]
IDEAL_RESPONSE = [
    [[[0, 57], 0.22114424407482147]],
    [[[0, 57], 0.5597518086433411]],
    [[[0, 40], 0.11774948984384537]],
    [[[0, 50], 0.2228381633758545]],
    [[[0, 61], 0.19825558364391327]],
    [[[0, 47], 0.19025272130966187]],
    [[[0, 40], 0.09688498824834824]],
    [[[0, 55], 0.2982504367828369]],
    [[[0, 12], 0.2982504367828369]]
]


class DummyResponse:

    def __init__(self, content):
        self.content = content


def make_dummy_post(response, sleep=0):
    @staticmethod
    async def dummy_post(url, headers, data):
        await asyncio.sleep(sleep)
        return DummyResponse(
            content=next(response) if isinstance(response, Iterator) else response
        )
    return dummy_post


class SemanticSearchTest(unittest.TestCase):

    def setUp(self):
        self.semantic_search = SemanticSearch()
        self.corpus = Corpus.from_file('deerwester')

    def tearDown(self):
        self.semantic_search.clear_cache()

    def test_make_chunks_small(self):
        chunks = self.semantic_search._make_chunks(
            self.corpus.documents, [100] * len(self.corpus.documents)
        )
        self.assertEqual(len(chunks), min(len(self.corpus.documents), MIN_CHUNKS))

    def test_make_chunks_medium(self):
        num_docs = len(self.corpus.documents)
        documents = self.corpus.documents
        if num_docs < MIN_CHUNKS:
            documents = [documents[0]] * MIN_CHUNKS
        chunks = self.semantic_search._make_chunks(
            documents, [MAX_PACKAGE_SIZE / MIN_CHUNKS - 1] * len(documents)
        )
        self.assertEqual(len(chunks), MIN_CHUNKS)

    def test_make_chunks_large(self):
        num_docs = len(self.corpus.documents)
        documents = self.corpus.documents
        if num_docs < MIN_CHUNKS:
            documents = [documents[0]] * MIN_CHUNKS * 100
        mps = MAX_PACKAGE_SIZE
        chunks = self.semantic_search._make_chunks(
            documents,
            [mps / 100] * (len(documents) - 2) + [0.3 * mps, 0.9 * mps, mps]
        )
        self.assertGreater(len(chunks), MIN_CHUNKS)

    @patch(PATCH_METHOD)
    def test_empty_corpus(self, mock):
        self.assertEqual(
            len(self.semantic_search(self.corpus.documents[:0], QUERIES)), 0
        )
        mock.request.assert_not_called()
        mock.get_response.assert_not_called()
        self.assertEqual(
            self.semantic_search._server_communicator._cache._cache_dict,
            dict()
        )

    @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
    def test_success(self):
        result = self.semantic_search(self.corpus.documents, QUERIES)
        self.assertEqual(result, IDEAL_RESPONSE)

    @patch(PATCH_METHOD, make_dummy_post(RESPONSE[0]))
    def test_success_chunks(self):
        num_docs = len(self.corpus.documents)
        documents = self.corpus.documents
        if num_docs < MIN_CHUNKS:
            documents = [documents[0]] * MIN_CHUNKS
        result = self.semantic_search(documents, QUERIES)
        self.assertEqual(len(result), MIN_CHUNKS)
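
To exercise this suite locally, a standard unittest invocation should work, assuming the orange3-text add-on (which ships the deerwester sample corpus) is installed:

    python -m unittest orangecontrib.text.tests.test_semantic_search -v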
