refresh on interval function refreshes index_info if enough time has … #333

Merged 4 commits on Feb 20, 2023
49 changes: 49 additions & 0 deletions src/marqo/tensor_search/index_meta_cache.py
@@ -6,15 +6,27 @@
import asyncio
import datetime
import time
import traceback
from multiprocessing import Process, Manager
from marqo.tensor_search.models.index_info import IndexInfo
from typing import Dict
from marqo import errors
from marqo.tensor_search import backend
from marqo.config import Config
from marqo.tensor_search.tensor_search_logging import get_logger

logger = get_logger(__name__)


index_info_cache = dict()

# The following is a non-thread-safe dict. Its purpose is to be used by request
# threads to calculate whether to refresh an index's cached index_info.
# Because it is not thread safe, there is a chance that multiple threads push out
# multiple refresh requests at the same time. It isn't a critical problem if that
# happens.
index_last_refreshed_time = dict()


def empty_cache():
global index_info_cache
@@ -48,6 +60,43 @@ def get_cache() -> Dict[str, IndexInfo]:
return index_info_cache


def refresh_index_info_on_interval(config: Config, index_name: str, interval_seconds: int) -> None:
"""Refreshes an index's index_info if inteval_seconds have elapsed since the last time it was refreshed

Non-thread safe, so there is a chance two threads both refresh index_info at the same time.
"""
try:
last_refreshed_time = index_last_refreshed_time[index_name]
except KeyError:
last_refreshed_time = datetime.datetime.min

now = datetime.datetime.now()

interval_as_time_delta = datetime.timedelta(seconds=interval_seconds)
if now - last_refreshed_time >= interval_as_time_delta:
        # We assume that we will successfully refresh the index info, so we set the time to now
        # to lower the chance of other threads simultaneously refreshing the cache.
index_last_refreshed_time[index_name] = now
try:
backend.get_index_info(config=config, index_name=index_name)

        except (errors.IndexNotFoundError, errors.NonTensorIndexError):
            # Trying to refresh the index and finding no tensor index is considered a
            # successful refresh of the index's index_info.
            pass
        except Exception as e2:
            # Any other exception is problematic. We set the index's last refreshed time back to
            # what we originally found, which lets another thread come in and refresh the index's index_info.
            index_last_refreshed_time[index_name] = last_refreshed_time
            logger.warning("refresh_index_info_on_interval(): error during background index_info refresh. Reason:"
                           f"\n{e2}")
            logger.debug("refresh_index_info_on_interval(): error during background index_info refresh. "
                         f"Traceback: \n{traceback.format_exc()}")
            raise e2


def refresh_index(config: Config, index_name: str) -> IndexInfo:
"""function to update an index, from the cluster.

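Read on its own, the new function is a simple gate: look up the index's last refresh time, optimistically bump it before calling the backend so that concurrent callers are unlikely to duplicate work, and roll it back if the call fails with an unexpected error. The following is a minimal, standalone sketch of that pattern, assuming nothing about Marqo's internals; refresh_on_interval, last_refreshed and do_refresh are illustrative names only.

import datetime

last_refreshed = {}

def refresh_on_interval(key, interval_seconds, do_refresh):
    """Call do_refresh() only if interval_seconds have elapsed for this key."""
    last = last_refreshed.get(key, datetime.datetime.min)
    now = datetime.datetime.now()
    if now - last < datetime.timedelta(seconds=interval_seconds):
        return
    # Optimistically record the refresh time so concurrent callers are
    # unlikely to trigger duplicate refreshes (best effort, not thread safe).
    last_refreshed[key] = now
    try:
        do_refresh()
    except Exception:
        # Roll back so another caller can retry sooner.
        last_refreshed[key] = last
        raise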
5 changes: 3 additions & 2 deletions src/marqo/tensor_search/tensor_search.py
@@ -867,10 +867,11 @@ def search(config: Config, index_name: str, text: Union[str, dict],
if index_name not in index_meta_cache.get_cache():
backend.get_index_info(config=config, index_name=index_name)

REFRESH_INTERVAL_SECONDS = 2
# update cache in the background
cache_update_thread = threading.Thread(
target=index_meta_cache.refresh_index,
args=(config, index_name))
target=index_meta_cache.refresh_index_info_on_interval,
args=(config, index_name, REFRESH_INTERVAL_SECONDS))
cache_update_thread.start()

if search_method.upper() == SearchMethod.TENSOR:
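The change to search() launches the refresh on a background thread with a hard-coded 2-second interval, so the request thread never waits on the mappings call. A hedged sketch of that fire-and-forget pattern, where handle_search and refresh_cache_entry are hypothetical stand-ins for search() and index_meta_cache.refresh_index_info_on_interval:

import threading

REFRESH_INTERVAL_SECONDS = 2  # matches the constant hard-coded in search()

def refresh_cache_entry(index_name, interval_seconds):
    # Stand-in for the interval-gated refresh added in index_meta_cache.py.
    pass

def handle_search(index_name):
    # Fire-and-forget: the request thread does not wait for the refresh.
    threading.Thread(
        target=refresh_cache_entry,
        args=(index_name, REFRESH_INTERVAL_SECONDS),
    ).start()
    # ... run the actual search against the (possibly slightly stale) cache ...

Even under many searches per second, this keeps outgoing mappings requests to roughly one per REFRESH_INTERVAL_SECONDS per index.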
229 changes: 228 additions & 1 deletion tests/tensor_search/test_index_meta_cache.py
@@ -1,7 +1,10 @@
import copy
import datetime
import pprint
import threading
import time
import unittest
import requests
from marqo.tensor_search.models.index_info import IndexInfo
from marqo.tensor_search.models import index_info
from marqo.tensor_search import tensor_search
@@ -12,6 +15,8 @@
from marqo.tensor_search.enums import TensorField, SearchMethod, IndexSettingsField
from marqo.tensor_search import configs
from tests.marqo_test import MarqoTestCase
from unittest import mock
from marqo import errors


class TestIndexMetaCache(MarqoTestCase):
@@ -199,7 +204,8 @@ def test_search_lexical_externally_created_field(self):
index_name=self.index_name_1, config=self.config, text="a line of text",
return_doc_ids=True, search_method=SearchMethod.LEXICAL)
assert len(result["hits"]) == 0
time.sleep(1)
# REFRESH INTERVAL IS 2 seconds
time.sleep(4)
result_2 = tensor_search.search(
index_name=self.index_name_1, config=self.config, text="a line of text",
return_doc_ids=True, search_method=SearchMethod.LEXICAL)
@@ -349,3 +355,224 @@ def test_index_settings_after_cache_refresh(self):
index_meta_cache.refresh_index(config=self.config, index_name=self.index_name_1)
ix_refreshed_info = index_meta_cache.get_index_info(config=self.config, index_name=self.index_name_1)
assert ix_refreshed_info.index_settings == expected_index_settings

def test_index_refresh_on_interval_multi_threaded(self):
""" This test involves spinning up 5 threads or so. these threads
try to refresh the cache every 0.1 seconds. Despite this, the
last_refresh_time ensures we only actually push out a mappings
request once per second.
Because checking the last_refresh_time isn't threadsafe, this
test may occasionally fail. Enabling log output, and allowing
more log output increases risk of test failure. However, most
the time it should pass.

"""
mock_get = mock.MagicMock()
@mock.patch('marqo._httprequests.ALLOWED_OPERATIONS', {mock_get})
@mock.patch('requests.get', mock_get)
def run():
N_seconds = 3
REFRESH_INTERVAL_SECONDS = 1
start_time = datetime.datetime.now()
num_threads = 5
total_loops = [0] * num_threads
sleep_time = 0.1

def threaded_while(thread_num, loop_record):
thread_loops = 0
while datetime.datetime.now() - start_time < datetime.timedelta(seconds=N_seconds):
cache_update_thread = threading.Thread(
target=index_meta_cache.refresh_index_info_on_interval,
args=(self.config, self.index_name_1, REFRESH_INTERVAL_SECONDS))
cache_update_thread.start()
time.sleep(sleep_time)
thread_loops += 1
loop_record[thread_num] = thread_loops

threads = [threading.Thread(target=threaded_while, args=(i, total_loops)) for i in range(num_threads)]
for th in threads:
th.start()

for th in threads:
th.join()
estimated_loops = round((N_seconds/sleep_time) * num_threads)
assert sum(total_loops) in range(estimated_loops - num_threads, estimated_loops + 1)
time.sleep(0.5) # let remaining thread complete, if needed

assert mock_get.call_count == N_seconds
return True
assert run()

def test_index_refresh_on_interval_multi_threaded_no_index(self):
""" If we encounter NonTensorIndexError/ IndexNotExists error
while refreshing the index info, it is considered a successful
refresh and the refresh happens on the intervals as usual.

isn't threadsafe, this test may occasionally fail.

"""
mock_get = mock.MagicMock()
mock_response = requests.Response()
mock_response.status_code = 200
mock_response.json = lambda: '{"a":"b"}'

# mock_get.return_value = mock_response
@mock.patch('marqo._httprequests.ALLOWED_OPERATIONS', {mock_get})
@mock.patch('requests.get', mock_get)
def run(error):
def use_error(*args, **kwargs):
raise error('')
mock_get.side_effect = use_error

N_seconds = 3
REFRESH_INTERVAL_SECONDS = 1
start_time = datetime.datetime.now()
num_threads = 5
total_loops = [0] * num_threads
sleep_time = 0.1

def threaded_while(thread_num, loop_record):
thread_loops = 0
while datetime.datetime.now() - start_time < datetime.timedelta(seconds=N_seconds):
cache_update_thread = threading.Thread(
target=index_meta_cache.refresh_index_info_on_interval,
args=(self.config, self.index_name_1, REFRESH_INTERVAL_SECONDS))
cache_update_thread.start()
time.sleep(sleep_time)
thread_loops += 1
loop_record[thread_num] = thread_loops

threads = [threading.Thread(target=threaded_while, args=(i, total_loops)) for i in range(num_threads)]
for th in threads:
th.start()

for th in threads:
th.join()
estimated_loops = round((N_seconds/sleep_time) * num_threads)
assert sum(total_loops) in range(estimated_loops - num_threads, estimated_loops + 1)
time.sleep(0.5) # let remaining thread complete, if needed
assert mock_get.call_count == N_seconds
return True
assert run(error=errors.NonTensorIndexError)
mock_get.reset_mock()
assert run(error=errors.IndexNotFoundError)

def test_index_refresh_on_interval_multi_threaded_errors(self):
""" If we encounter any error besides
NonTensorIndexError/ IndexNotExists we this is considered a
failed refresh, which doesn't prevent other threads from
trying to update it.

This is not threadsafe and may occassionally fail

"""
mock_get = mock.MagicMock()
mock_response = requests.Response()
mock_response.status_code = 200
mock_response.json = lambda: '{"a":"b"}'

# mock_get.return_value = mock_response
@mock.patch('marqo._httprequests.ALLOWED_OPERATIONS', {mock_get})
@mock.patch('requests.get', mock_get)
def run(error):
def use_error(*args, **kwargs):
raise error('')

mock_get.side_effect = use_error

N_seconds = 3
REFRESH_INTERVAL_SECONDS = 1
start_time = datetime.datetime.now()
num_threads = 5
total_loops = [0] * num_threads
sleep_time = 0.1

def threaded_while(thread_num, loop_record):
thread_loops = 0
while datetime.datetime.now() - start_time < datetime.timedelta(seconds=N_seconds):
cache_update_thread = threading.Thread(
target=index_meta_cache.refresh_index_info_on_interval,
args=(self.config, self.index_name_1, REFRESH_INTERVAL_SECONDS))
cache_update_thread.start()
time.sleep(sleep_time)
thread_loops += 1
loop_record[thread_num] = thread_loops

threads = [threading.Thread(target=threaded_while, args=(i, total_loops)) for i in range(num_threads)]
for th in threads:
th.start()

for th in threads:
th.join()
estimated_loops = round((N_seconds / sleep_time) * num_threads)
assert sum(total_loops) in range(estimated_loops - num_threads, estimated_loops + 1)
time.sleep(0.5) # let remaining thread complete, if needed
            # Because these refreshes fail, we set the last_refresh_time back to the original,
            # allowing other threads to refresh index_info.
assert mock_get.call_count in range(estimated_loops - num_threads, estimated_loops + 1)
return True

assert run(error=ValueError)
mock_get.reset_mock()
assert run(error=requests.ConnectionError)

def test_search_index_refresh_on_interval_multi_threaded(self):
""" Same as test_index_refresh_on_interval_multi_threaded() ,
but using the search endpoint.

The same caveat applies: Because checking the last_refresh_time
isn't threadsafe, this test may occasionally fail.
"""

mock_get = mock.Mock()
mock_response = requests.Response()
mock_response.status_code = 200
mock_response.json = lambda: '{"a":"b"}'
mock_get.return_value = mock_response

        # We need to get something into the cache first; otherwise the threads
        # will see an empty cache and try to fill it.
try:
tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[{"hi": "hello"}],
auto_refresh=False)
except IndexNotFoundError:
pass
@mock.patch('marqo._httprequests.ALLOWED_OPERATIONS', {mock_get})
@mock.patch('marqo._httprequests.requests.get', mock_get)
def run():

# requests.get('23456')
N_seconds = 4
# the following is hard coded in search()
REFRESH_INTERVAL_SECONDS = 2
start_time = datetime.datetime.now()
num_threads = 5
total_loops = [0] * num_threads
sleep_time = 0.1

def threaded_while(thread_num, loop_record):
thread_loops = 0
while datetime.datetime.now() - start_time < datetime.timedelta(seconds=N_seconds):
cache_update_thread = threading.Thread(
target=tensor_search.search,
kwargs={"config": self.config, "index_name": self.index_name_1, "text": "hello" })
cache_update_thread.start()
time.sleep(sleep_time)
thread_loops += 1
loop_record[thread_num] = thread_loops
threads = [threading.Thread(target=threaded_while, args=(i, total_loops)) for i in range(num_threads)]
for th in threads:
th.start()
for th in threads:
th.join()

estimated_loops = round((N_seconds/sleep_time) * num_threads)
assert sum(total_loops) in range(estimated_loops - num_threads, estimated_loops + 1)
time.sleep(0.5) # let remaining thread complete, if needed
mappings_call_count = len([c for c in mock_get.mock_calls if '_mapping' in str(c)])
            # For the refresh interval hardcoded in search(), which is 2 seconds, we expect a total
            # of only 2 calls to the mappings endpoint, even though there are many more search requests.
assert mappings_call_count == round(N_seconds/REFRESH_INTERVAL_SECONDS)
return True
assert run()
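
The verification idea behind these tests can be reproduced outside Marqo: poll an interval-gated refresh from several threads for a few seconds, then assert that the (mocked) backend was hit only about once per interval. A minimal sketch under those assumptions; demo, hammer and backend_call are illustrative names, not Marqo's test fixtures.

import datetime
import threading
import time
from unittest import mock

def demo():
    backend_call = mock.MagicMock()
    last = {"t": datetime.datetime.min}

    def refresh_on_interval(interval_seconds):
        now = datetime.datetime.now()
        if now - last["t"] >= datetime.timedelta(seconds=interval_seconds):
            last["t"] = now    # optimistic, best-effort gate (not thread safe)
            backend_call()     # stands in for the mocked mappings request

    stop_at = datetime.datetime.now() + datetime.timedelta(seconds=3)

    def hammer():
        while datetime.datetime.now() < stop_at:
            refresh_on_interval(interval_seconds=1)
            time.sleep(0.1)

    threads = [threading.Thread(target=hammer) for _ in range(5)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # ~3 seconds of polling with a 1-second interval => roughly 3 backend calls;
    # because the gate is not thread safe, a few duplicate calls are possible.
    assert 3 <= backend_call.call_count <= 8

demo()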