Skip to content

Commit

Permalink
fix: Random port allocation for python server in tests (feast-dev#2710)
Browse files Browse the repository at this point in the history
Signed-off-by: Oleksii Moskalenko <moskalenko.alexey@gmail.com>
  • Loading branch information
pyalex authored May 16, 2022
1 parent 7a043eb commit dee8090
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 31 deletions.
27 changes: 18 additions & 9 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,19 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):


@pytest.fixture(scope="session")
def python_server(environment):
assert not _check_port_open("localhost", environment.get_local_server_port())
def feature_server_endpoint(environment):
if (
not environment.python_feature_server
or environment.test_repo_config.provider != "local"
):
yield environment.feature_store.get_feature_server_endpoint()
return

port = _free_port()

proc = Process(
target=start_test_local_server,
args=(environment.feature_store.repo_path, environment.get_local_server_port()),
args=(environment.feature_store.repo_path, port),
)
if (
environment.python_feature_server
Expand All @@ -287,14 +294,10 @@ def python_server(environment):
proc.start()
# Wait for server to start
wait_retry_backoff(
lambda: (
None,
_check_port_open("localhost", environment.get_local_server_port()),
),
timeout_secs=10,
lambda: (None, _check_port_open("localhost", port)), timeout_secs=10,
)

yield
yield f"http://localhost:{port}"

if proc.is_alive():
proc.kill()
Expand All @@ -314,6 +317,12 @@ def _check_port_open(host, port) -> bool:
return sock.connect_ex((host, port)) == 0


def _free_port():
sock = socket.socket()
sock.bind(("", 0))
return sock.getsockname()[1]


@pytest.fixture(scope="session")
def universal_data_sources(environment) -> TestData:
return construct_universal_test_data(environment)
Expand Down
20 changes: 0 additions & 20 deletions sdk/python/tests/integration/feature_repos/repo_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import importlib
import json
import os
import re
import tempfile
import uuid
from dataclasses import dataclass
Expand Down Expand Up @@ -328,29 +327,10 @@ class Environment:
worker_id: str
online_store_creator: Optional[OnlineStoreCreator] = None

next_id = 0

def __post_init__(self):
self.end_date = datetime.utcnow().replace(microsecond=0, second=0, minute=0)
self.start_date: datetime = self.end_date - timedelta(days=3)

Environment.next_id += 1
self.id = Environment.next_id

def get_feature_server_endpoint(self) -> str:
if self.python_feature_server and self.test_repo_config.provider == "local":
return f"http://localhost:{self.get_local_server_port()}"
return self.feature_store.get_feature_server_endpoint()

def get_local_server_port(self) -> int:
# Heuristic when running with xdist to extract unique ports for each worker
parsed_worker_id = re.findall("gw(\\d+)", self.worker_id)
if len(parsed_worker_id) != 0:
worker_id_num = int(parsed_worker_id[0])
else:
worker_id_num = 0
return 6000 + 100 * worker_id_num + self.id


def table_name_from_data_source(ds: DataSource) -> Optional[str]:
if hasattr(ds, "table_ref"):
Expand Down
15 changes: 13 additions & 2 deletions sdk/python/tests/integration/online_store/test_universal_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def _get_online_features_dict_remotely(

def get_online_features_dict(
environment: Environment,
endpoint: str,
features: Union[List[str], FeatureService],
entity_rows: List[Dict[str, Any]],
full_feature_names: bool = False,
Expand All @@ -305,7 +306,6 @@ def get_online_features_dict(
assertpy.assert_that(online_features).is_not_none()
dict1 = online_features.to_dict()

endpoint = environment.get_feature_server_endpoint()
# If endpoint is None, it means that a local / remote feature server aren't configured
if endpoint is not None:
dict2 = _get_online_features_dict_remotely(
Expand Down Expand Up @@ -447,7 +447,7 @@ def test_online_retrieval_with_event_timestamps(
@pytest.mark.goserver
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
def test_online_retrieval(
environment, universal_data_sources, python_server, full_feature_names
environment, universal_data_sources, feature_server_endpoint, full_feature_names
):
fs = environment.feature_store
entities, datasets, data_sources = universal_data_sources
Expand Down Expand Up @@ -547,6 +547,7 @@ def test_online_retrieval(

online_features_dict = get_online_features_dict(
environment=environment,
endpoint=feature_server_endpoint,
features=feature_refs,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand All @@ -556,6 +557,7 @@ def test_online_retrieval(
# feature isn't requested.
online_features_no_conv_rate = get_online_features_dict(
environment=environment,
endpoint=feature_server_endpoint,
features=[ref for ref in feature_refs if ref != "driver_stats:conv_rate"],
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand Down Expand Up @@ -616,6 +618,7 @@ def test_online_retrieval(
# Check what happens for missing values
missing_responses_dict = get_online_features_dict(
environment=environment,
endpoint=feature_server_endpoint,
features=feature_refs,
entity_rows=[{"driver_id": 0, "customer_id": 0, "val_to_add": 100}],
full_feature_names=full_feature_names,
Expand All @@ -635,13 +638,15 @@ def test_online_retrieval(
with pytest.raises(RequestDataNotFoundInEntityRowsException):
get_online_features_dict(
environment=environment,
endpoint=feature_server_endpoint,
features=feature_refs,
entity_rows=[{"driver_id": 0, "customer_id": 0}],
full_feature_names=full_feature_names,
)

assert_feature_service_correctness(
environment,
feature_server_endpoint,
feature_service,
entity_rows,
full_feature_names,
Expand All @@ -659,6 +664,7 @@ def test_online_retrieval(
]
assert_feature_service_entity_mapping_correctness(
environment,
feature_server_endpoint,
feature_service_entity_mapping,
entity_rows,
full_feature_names,
Expand Down Expand Up @@ -856,6 +862,7 @@ def get_latest_feature_values_for_location_df(entity_row, origin_df, destination

def assert_feature_service_correctness(
environment,
endpoint,
feature_service,
entity_rows,
full_feature_names,
Expand All @@ -866,6 +873,7 @@ def assert_feature_service_correctness(
):
feature_service_online_features_dict = get_online_features_dict(
environment=environment,
endpoint=endpoint,
features=feature_service,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand Down Expand Up @@ -905,6 +913,7 @@ def assert_feature_service_correctness(

def assert_feature_service_entity_mapping_correctness(
environment,
endpoint,
feature_service,
entity_rows,
full_feature_names,
Expand All @@ -914,6 +923,7 @@ def assert_feature_service_entity_mapping_correctness(
if full_feature_names:
feature_service_online_features_dict = get_online_features_dict(
environment=environment,
endpoint=endpoint,
features=feature_service,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand Down Expand Up @@ -948,6 +958,7 @@ def assert_feature_service_entity_mapping_correctness(
with pytest.raises(FeatureNameCollisionError):
get_online_features_dict(
environment=environment,
endpoint=endpoint,
features=feature_service,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand Down

0 comments on commit dee8090

Please sign in to comment.