diff --git a/orangecontrib/text/semantic_search.py b/orangecontrib/text/semantic_search.py
new file mode 100644
index 000000000..cbdf0ce02
--- /dev/null
+++ b/orangecontrib/text/semantic_search.py
@@ -0,0 +1,140 @@
+import json
+import base64
+import zlib
+import sys
+import re
+from typing import Any, Optional, List, Tuple, Callable
+
+import numpy as np
+
+from Orange.misc.server_embedder import ServerEmbedderCommunicator
+from orangecontrib.text import Corpus
+from orangecontrib.text.misc import wait_nltk_data
+
+
+MAX_PACKAGE_SIZE = 50000
+MIN_CHUNKS = 20
+
+
+class SemanticSearch:
+
+    def __init__(self) -> None:
+        self._server_communicator = _ServerCommunicator(
+            model_name='semantic-search',
+            max_parallel_requests=100,
+            server_url='https://apiv2.garaza.io',
+            embedder_type='text',
+        )
+
+    def __call__(
+        self, texts: List[str], queries: List[str],
+        progress_callback: Optional[Callable] = None
+    ) -> List[Optional[List[Tuple[Tuple[int, int], float]]]]:
+        """Compute matches for the given documents and queries.
+
+        Parameters
+        ----------
+        texts: List[str]
+            A list of raw texts to be matched.
+        queries: List[str]
+            A list of query words/phrases.
+
+        Returns
+        -------
+        List[Optional[List[Tuple[Tuple[int, int], float]]]]
+            The elements of the outer list correspond to the documents. Each
+            entry is either None or a list of matches, one match per sentence.
+            Each match has the form ((start_idx, end_idx), score). Note that
+            the tuples are converted to lists before the result is returned.
+        """
+
+        if len(texts) == 0 or len(queries) == 0:
+            return [None] * len(texts)
+
+        skipped = list()
+        # queries are zlib-compressed and base64-encoded to keep the payload small
+        queries_enc = base64.b64encode(
+            zlib.compress(
+                json.dumps(queries).encode('utf-8', 'replace'),
+                level=-1
+            )
+        ).decode('utf-8', 'replace')
+
+        encoded_texts = list()
+        sizes = list()
+        chunks = list()
+        for i, text in enumerate(texts):
+            encoded = base64.b64encode(zlib.compress(
+                text.encode('utf-8', 'replace'), level=-1)
+            ).decode('utf-8', 'replace')
+            size = sys.getsizeof(encoded)
+            if size > MAX_PACKAGE_SIZE:
+                skipped.append(i)
+                continue
+            encoded_texts.append(encoded)
+            sizes.append(size)
+
+        # split the encoded texts into chunks that fit the package size limit
+        chunks_ = self._make_chunks(encoded_texts, sizes)
+        for chunk in chunks_:
+            chunks.append([chunk, queries_enc])
+
+        result_ = self._server_communicator.embedd_data(
+            chunks, processed_callback=progress_callback,
+        )
+        if result_ is None:
+            return [None] * len(texts)
+
+        result = list()
+        for chunk in result_:
+            result.extend(chunk)
+
+        results = list()
+        idx = 0
+        for i in range(len(texts)):
+            if i in skipped:
+                results.append(None)
+            else:
+                results.append(result[idx])
+                idx += 1
+
+        return results
+
+    def _make_chunks(self, encoded_texts, sizes, depth=0):
+        chunks = np.array_split(encoded_texts, MIN_CHUNKS if depth == 0 else 2)
+        chunk_sizes = np.array_split(sizes, MIN_CHUNKS if depth == 0 else 2)
+        result = list()
+        for i in range(len(chunks)):
+            # chunks that are still too large are recursively split in half
+            if np.sum(chunk_sizes[i]) > MAX_PACKAGE_SIZE:
+                result.extend(self._make_chunks(chunks[i], chunk_sizes[i], depth + 1))
+            else:
+                result.append(chunks[i])
+        return [list(r) for r in result if len(r) > 0]
+
+    def set_cancelled(self):
+        if hasattr(self, '_server_communicator'):
+            self._server_communicator.set_cancelled()
+
+    def clear_cache(self):
+        if self._server_communicator:
+            self._server_communicator.clear_cache()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, ex_type, value, traceback):
+        self.set_cancelled()
+
+    def __del__(self):
+        self.__exit__(None, None, None)
+
+
+class _ServerCommunicator(ServerEmbedderCommunicator):
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.content_type = 'application/json'
+
+    async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
+        return json.dumps(data_instance).encode('utf-8', 'replace')
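For reference, a minimal offline sketch of the payload that SemanticSearch.__call__ assembles for a single chunk: texts and queries are zlib-compressed and base64-encoded, and each chunk is sent as a JSON array of the form [[encoded_text, ...], encoded_queries], which _ServerCommunicator._encode_data_instance serializes. The encode helper and the two documents (the first two texts of the deerwester corpus used in the tests) are only illustrative; the snippet does not contact the server.

import base64
import json
import sys
import zlib

MAX_PACKAGE_SIZE = 50000  # same limit as in semantic_search.py


def encode(payload: bytes) -> str:
    # zlib-compress, then base64-encode, as SemanticSearch.__call__ does
    return base64.b64encode(
        zlib.compress(payload, level=-1)
    ).decode('utf-8', 'replace')


texts = [
    'Human machine interface for lab abc computer applications',
    'A survey of user opinion of computer system response time',
]
queries = ['test query', 'another test query']

encoded_texts = [encode(t.encode('utf-8', 'replace')) for t in texts]
queries_enc = encode(json.dumps(queries).encode('utf-8', 'replace'))

# texts whose encoded size exceeds MAX_PACKAGE_SIZE are skipped entirely
assert all(sys.getsizeof(e) <= MAX_PACKAGE_SIZE for e in encoded_texts)

# one chunk == one request body: [[encoded_text, ...], encoded_queries]
request_body = json.dumps([encoded_texts, queries_enc]).encode('utf-8', 'replace')
print(len(request_body), 'bytes')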
diff --git a/orangecontrib/text/tests/test_semantic_search.py b/orangecontrib/text/tests/test_semantic_search.py
new file mode 100644
index 000000000..736e75fe8
--- /dev/null
+++ b/orangecontrib/text/tests/test_semantic_search.py
@@ -0,0 +1,117 @@
+import unittest
+from unittest.mock import patch
+from collections.abc import Iterator
+import asyncio
+import numpy as np
+
+from orangecontrib.text.semantic_search import (
+    SemanticSearch,
+    MIN_CHUNKS,
+    MAX_PACKAGE_SIZE
+)
+from orangecontrib.text import Corpus
+
+PATCH_METHOD = 'httpx.AsyncClient.post'
+QUERIES = ['test query', 'another test query']
+RESPONSE = [
+    b'{ "embedding": [[[[0, 57], 0.22114424407482147]]] }',
+    b'{ "embedding": [[[[0, 57], 0.5597518086433411]]] }',
+    b'{ "embedding": [[[[0, 40], 0.11774948984384537]]] }',
+    b'{ "embedding": [[[[0, 50], 0.2228381633758545]]] }',
+    b'{ "embedding": [[[[0, 61], 0.19825558364391327]]] }',
+    b'{ "embedding": [[[[0, 47], 0.19025272130966187]]] }',
+    b'{ "embedding": [[[[0, 40], 0.09688498824834824]]] }',
+    b'{ "embedding": [[[[0, 55], 0.2982504367828369]]] }',
+    b'{ "embedding": [[[[0, 12], 0.2982504367828369]]] }',
+]
+IDEAL_RESPONSE = [
+    [[[0, 57], 0.22114424407482147]],
+    [[[0, 57], 0.5597518086433411]],
+    [[[0, 40], 0.11774948984384537]],
+    [[[0, 50], 0.2228381633758545]],
+    [[[0, 61], 0.19825558364391327]],
+    [[[0, 47], 0.19025272130966187]],
+    [[[0, 40], 0.09688498824834824]],
+    [[[0, 55], 0.2982504367828369]],
+    [[[0, 12], 0.2982504367828369]]
+]
+
+
+class DummyResponse:
+
+    def __init__(self, content):
+        self.content = content
+
+
+def make_dummy_post(response, sleep=0):
+    @staticmethod
+    async def dummy_post(url, headers, data):
+        await asyncio.sleep(sleep)
+        return DummyResponse(
+            content=next(response) if isinstance(response, Iterator) else response
+        )
+    return dummy_post
+
+
+class SemanticSearchTest(unittest.TestCase):
+
+    def setUp(self):
+        self.semantic_search = SemanticSearch()
+        self.corpus = Corpus.from_file('deerwester')
+
+    def tearDown(self):
+        self.semantic_search.clear_cache()
+
+    def test_make_chunks_small(self):
+        chunks = self.semantic_search._make_chunks(
+            self.corpus.documents, [100] * len(self.corpus.documents)
+        )
+        self.assertEqual(len(chunks), min(len(self.corpus.documents), MIN_CHUNKS))
+
+    def test_make_chunks_medium(self):
+        num_docs = len(self.corpus.documents)
+        documents = self.corpus.documents
+        if num_docs < MIN_CHUNKS:
+            documents = [documents[0]] * MIN_CHUNKS
+        chunks = self.semantic_search._make_chunks(
+            documents, [MAX_PACKAGE_SIZE / MIN_CHUNKS - 1] * len(documents)
+        )
+        self.assertEqual(len(chunks), MIN_CHUNKS)
+
+    def test_make_chunks_large(self):
+        num_docs = len(self.corpus.documents)
+        documents = self.corpus.documents
+        if num_docs < MIN_CHUNKS:
+            documents = [documents[0]] * MIN_CHUNKS * 100
+        mps = MAX_PACKAGE_SIZE
+        chunks = self.semantic_search._make_chunks(
+            documents,
+            [mps / 100] * (len(documents) - 2) + [0.3 * mps, 0.9 * mps, mps]
+        )
+        self.assertGreater(len(chunks), MIN_CHUNKS)
+
+    @patch(PATCH_METHOD)
+    def test_empty_corpus(self, mock):
+        self.assertEqual(
+            len(self.semantic_search(self.corpus.documents[:0], QUERIES)), 0
+        )
+        mock.request.assert_not_called()
+        mock.get_response.assert_not_called()
+        self.assertEqual(
+            self.semantic_search._server_communicator._cache._cache_dict,
+            dict()
+        )
+
+    @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
+    def test_success(self):
+        result = self.semantic_search(self.corpus.documents, QUERIES)
+        self.assertEqual(result, IDEAL_RESPONSE)
+
+    @patch(PATCH_METHOD, make_dummy_post(RESPONSE[0]))
+    def test_success_chunks(self):
+        num_docs = len(self.corpus.documents)
+        documents = self.corpus.documents
+        if num_docs < MIN_CHUNKS:
+            documents = [documents[0]] * MIN_CHUNKS
+        result = self.semantic_search(documents, QUERIES)
+        self.assertEqual(len(result), MIN_CHUNKS)
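Below is a hedged usage sketch of the new API, mirroring test_success above. It assumes the remote service at https://apiv2.garaza.io is reachable; when no results can be obtained (or the input is empty), the call returns a list of None values of the same length as the input, so callers should handle None entries. The query strings here are arbitrary examples.

from orangecontrib.text import Corpus
from orangecontrib.text.semantic_search import SemanticSearch

corpus = Corpus.from_file('deerwester')
queries = ['graph minors', 'human computer interaction']

semantic_search = SemanticSearch()
# matches[i] is None or a list of [[start_idx, end_idx], score] entries,
# one per matched sentence of document i
matches = semantic_search(corpus.documents, queries)

for document, match in zip(corpus.documents, matches):
    if match is None:
        continue
    (start, end), score = match[0]
    print(round(score, 3), document[start:end])

The optional progress_callback argument is passed through to the server communicator as processed_callback, so it is presumably invoked as chunks are processed.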