Skip to content

Commit

Permalink
ci: remove unnecessary fixtures from integration tests (#4544)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Sep 21, 2024
1 parent 745cba7 commit 145b850
Show file tree
Hide file tree
Showing 11 changed files with 1,370 additions and 1,313 deletions.
Empty file added integration_tests/__init__.py
Empty file.
759 changes: 759 additions & 0 deletions integration_tests/_helpers.py

Large diffs are not rendered by default.

Empty file.
503 changes: 14 additions & 489 deletions integration_tests/auth/conftest.py

Large diffs are not rendered by default.

1,007 changes: 462 additions & 545 deletions integration_tests/auth/test_auth.py

Large diffs are not rendered by default.

332 changes: 110 additions & 222 deletions integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,76 @@
import os
import sys
from contextlib import ExitStack, contextmanager
from functools import partial
from subprocess import PIPE, STDOUT
from threading import Lock, Thread
from time import sleep, time
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Protocol, cast
from contextlib import ExitStack
from dataclasses import asdict
from itertools import count, starmap
from typing import Generator, Iterator, Optional, Tuple, cast
from unittest import mock
from urllib.parse import urljoin
from urllib.request import urlopen

import httpx
import pytest
from _pytest.fixtures import SubRequest
from _pytest.tmpdir import TempPathFactory
from faker import Faker
from openinference.semconv.resource import ResourceAttributes
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter
from opentelemetry.trace import Span, Tracer
from phoenix.auth import REQUIREMENTS_FOR_PHOENIX_SECRET
from phoenix.config import (
ENV_PHOENIX_GRPC_PORT,
ENV_PHOENIX_PORT,
ENV_PHOENIX_SQL_DATABASE_SCHEMA,
ENV_PHOENIX_SQL_DATABASE_URL,
ENV_PHOENIX_WORKING_DIR,
get_base_url,
get_env_database_connection_str,
get_env_database_schema,
get_env_grpc_port,
get_env_host,
)
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
from portpicker import pick_unused_port # type: ignore[import-untyped]
from psutil import STATUS_ZOMBIE, Popen
from sqlalchemy import URL, create_engine, make_url, text
from sqlalchemy.exc import OperationalError
from typing_extensions import TypeAlias

_ProjectName: TypeAlias = str
_SpanName: TypeAlias = str
_Headers: TypeAlias = Dict[str, Any]


class _GetGqlSpans(Protocol):
def __call__(self, *keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: ...


class _SpanExporterFactory(Protocol):
def __call__(
self,
*,
headers: Optional[_Headers] = None,
) -> SpanExporter: ...
from sqlalchemy import URL, make_url
from typing_extensions import assert_never

from ._helpers import (
_DEFAULT_ADMIN,
_MEMBER,
_Email,
_GetUser,
_grpc_span_exporter,
_http_span_exporter,
_Password,
_Profile,
_random_schema,
_RoleOrUser,
_SpanExporterFactory,
_User,
_UserFactory,
_UserGenerator,
_Username,
)


class _GetTracer(Protocol):
def __call__(
self,
*,
project_name: _ProjectName,
exporter: SpanExporter,
) -> Tracer: ...
@pytest.fixture(
scope="session",
params=[
pytest.param("sqlite:///:memory:", id="sqlite"),
pytest.param(
"postgresql://127.0.0.1:5432/postgres?user=postgres&password=phoenix",
id="postgresql",
),
],
)
def _sql_database_url(request: SubRequest) -> URL:
return make_url(request.param)


class _StartSpan(Protocol):
def __call__(
self,
*,
project_name: _ProjectName,
span_name: _SpanName,
exporter: SpanExporter,
) -> Span: ...
@pytest.fixture(scope="session", params=["http", "grpc"])
def _span_exporter(request: SubRequest) -> _SpanExporterFactory:
if request.param == "http":
return _http_span_exporter
if request.param == "grpc":
return _grpc_span_exporter
raise ValueError(f"Unknown exporter: {request.param}")


@pytest.fixture(scope="class")
def fake() -> Faker:
@pytest.fixture(scope="module")
def _fake() -> Faker:
return Faker()


@pytest.fixture(autouse=True, scope="class")
def env(tmp_path_factory: TempPathFactory) -> Iterator[None]:
@pytest.fixture(autouse=True, scope="module")
def _env(tmp_path_factory: TempPathFactory) -> Iterator[None]:
tmp = tmp_path_factory.getbasetemp()
values = (
(ENV_PHOENIX_PORT, str(pick_unused_port())),
Expand All @@ -91,194 +81,92 @@ def env(tmp_path_factory: TempPathFactory) -> Iterator[None]:
yield


@pytest.fixture(
scope="session",
params=[
pytest.param("sqlite:///:memory:", id="sqlite"),
pytest.param(
"postgresql://127.0.0.1:5432/postgres?user=postgres&password=phoenix",
id="postgresql",
),
],
)
def sql_database_url(request: SubRequest) -> URL:
return make_url(request.param)


@pytest.fixture(autouse=True, scope="class")
def env_phoenix_sql_database_url(
sql_database_url: URL,
fake: Faker,
@pytest.fixture(autouse=True, scope="module")
def _env_phoenix_sql_database_url(
_sql_database_url: URL,
_fake: Faker,
) -> Iterator[None]:
values = [(ENV_PHOENIX_SQL_DATABASE_URL, sql_database_url.render_as_string())]
values = [(ENV_PHOENIX_SQL_DATABASE_URL, _sql_database_url.render_as_string())]
with ExitStack() as stack:
if sql_database_url.get_backend_name().startswith("postgresql"):
schema = stack.enter_context(_random_schema(sql_database_url, fake))
if _sql_database_url.get_backend_name().startswith("postgresql"):
schema = stack.enter_context(_random_schema(_sql_database_url))
values.append((ENV_PHOENIX_SQL_DATABASE_SCHEMA, schema))
stack.enter_context(mock.patch.dict(os.environ, values))
yield


@pytest.fixture
def get_gql_spans(
httpx_client: Callable[[], httpx.Client],
) -> _GetGqlSpans:
def _(*keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]:
out = "name spans{edges{node{" + " ".join(keys) + "}}}"
query = dict(query="query{projects{edges{node{" + out + "}}}}")
resp = httpx_client().post(urljoin(get_base_url(), "graphql"), json=query)
resp.raise_for_status()
resp_dict = resp.json()
assert not resp_dict.get("errors")
return {
project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]]
for project in resp_dict["data"]["projects"]["edges"]
}
@pytest.fixture(scope="module")
def _emails(_fake: Faker) -> Iterator[_Email]:
return (_fake.unique.email() for _ in count())

return _

@pytest.fixture(scope="module")
def _passwords(_fake: Faker) -> Iterator[_Password]:
return (_fake.unique.password(**asdict(REQUIREMENTS_FOR_PHOENIX_SECRET)) for _ in count())

@pytest.fixture(scope="session")
def http_span_exporter() -> _SpanExporterFactory:
def _(
*,
headers: Optional[_Headers] = None,
) -> SpanExporter:
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

endpoint = urljoin(get_base_url(), "v1/traces")
exporter = OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1)
exporter._MAX_RETRY_TIMEOUT = 2
return exporter
@pytest.fixture(scope="module")
def _usernames(_fake: Faker) -> Iterator[_Username]:
return (_fake.unique.pystr() for _ in count())

return _

@pytest.fixture(scope="module")
def _profiles(
_emails: Iterator[_Email],
_passwords: Iterator[_Password],
_usernames: Iterator[_Username],
) -> Iterator[_Profile]:
return starmap(_Profile, zip(_emails, _passwords, _usernames))

@pytest.fixture(scope="session")
def grpc_span_exporter() -> _SpanExporterFactory:
def _(
*,
headers: Optional[_Headers] = None,
) -> SpanExporter:
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter

host = get_env_host()
if host == "0.0.0.0":
host = "127.0.0.1"
endpoint = f"http://{host}:{get_env_grpc_port()}"
return OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1)
@pytest.fixture(scope="module")
def _users(
_profiles: Iterator[_Profile],
) -> _UserGenerator:
def _() -> Generator[Optional[_User], Tuple[UserRoleInput, Optional[_Profile]], None]:
role, profile = yield None
while True:
user = _DEFAULT_ADMIN.create_user(role, profile=profile or next(_profiles))
role, profile = yield user

return _
g = _()
next(g)
return cast(_UserGenerator, g)


@pytest.fixture(scope="session", params=["http", "grpc"])
def span_exporter(request: SubRequest) -> _SpanExporterFactory:
if request.param == "http":
return cast(_SpanExporterFactory, request.getfixturevalue("http_span_exporter"))
if request.param == "grpc":
return cast(_SpanExporterFactory, request.getfixturevalue("grpc_span_exporter"))
raise ValueError(f"Unknown exporter: {request.param}")


@pytest.fixture(scope="session")
def get_tracer() -> _GetTracer:
@pytest.fixture(scope="module")
def _new_user(
_users: _UserGenerator,
) -> _UserFactory:
def _(
role: UserRoleInput = _MEMBER,
/,
*,
project_name: str,
exporter: SpanExporter,
) -> Tracer:
resource = Resource({ResourceAttributes.PROJECT_NAME: project_name})
tracer_provider = TracerProvider(resource=resource)
tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
return tracer_provider.get_tracer(__name__)
profile: Optional[_Profile] = None,
) -> _User:
return _users.send((role, profile))

return _


@pytest.fixture(scope="session")
def start_span(
get_tracer: _GetTracer,
) -> _StartSpan:
@pytest.fixture(scope="module")
def _get_user(
_new_user: _UserFactory,
) -> _GetUser:
def _(
role_or_user: _RoleOrUser = _MEMBER,
/,
*,
project_name: str,
span_name: str,
exporter: SpanExporter,
) -> Span:
return get_tracer(project_name=project_name, exporter=exporter).start_span(span_name)
profile: Optional[_Profile] = None,
) -> _User:
assert profile is None or isinstance(role_or_user, UserRoleInput)
if isinstance(role_or_user, _User):
user = role_or_user
return user
elif isinstance(role_or_user, UserRoleInput):
role = role_or_user
return _new_user(role, profile=profile)
else:
assert_never(role_or_user)

return _


@pytest.fixture(scope="session")
def httpx_client() -> Callable[[], httpx.Client]:
# Having no timeout is useful when stepping through the debugger on the server side.
return partial(httpx.Client, timeout=None)


@pytest.fixture(scope="session")
def server() -> Callable[[], ContextManager[None]]:
@contextmanager
def _() -> Iterator[None]:
if get_env_database_connection_str().startswith("postgresql"):
# double-check for safety
assert get_env_database_schema()
command = f"{sys.executable} -m phoenix.server.main serve"
process = Popen(command.split(), stdout=PIPE, stderr=STDOUT, text=True, env=os.environ)
log: List[str] = []
lock: Lock = Lock()
Thread(target=capture_stdout, args=(process, log, lock), daemon=True).start()
t = 60
time_limit = time() + t
timed_out = False
url = urljoin(get_base_url(), "healthz")
while not timed_out and is_alive(process):
sleep(0.1)
try:
urlopen(url)
break
except BaseException:
timed_out = time() > time_limit
try:
if timed_out:
raise TimeoutError(f"Server did not start within {t} seconds.")
assert is_alive(process)
with lock:
for line in log:
print(line, end="")
log.clear()
yield
process.terminate()
process.wait(10)
finally:
for line in log:
print(line, end="")

return _


def is_alive(process: Popen) -> bool:
return process.is_running() and process.status() != STATUS_ZOMBIE


def capture_stdout(process: Popen, log: List[str], lock: Lock) -> None:
while is_alive(process):
line = process.stdout.readline()
if line or (log and log[-1] != line):
with lock:
log.append(line)


@contextmanager
def _random_schema(url: URL, fake: Faker) -> Iterator[str]:
engine = create_engine(url.set(drivername="postgresql+psycopg"))
try:
engine.connect()
except OperationalError as ex:
pytest.skip(f"PostgreSQL unavailable: {ex}")
schema = fake.unique.pystr().lower()
yield schema
with engine.connect() as conn:
conn.execute(text(f"DROP SCHEMA IF EXISTS {schema} CASCADE;"))
conn.commit()
engine.dispose()
1 change: 0 additions & 1 deletion integration_tests/mypy.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
[mypy]
strict = true
explicit_package_bases = true
exclude = (^evals|^notebooks)
Loading

0 comments on commit 145b850

Please sign in to comment.