Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/docker/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
version: "3.3"
services:
ydb:
image: ydbplatform/local-ydb:trunk
restart: always
ports:
- "2136:2136"
hostname: localhost
environment:
- YDB_USE_IN_MEMORY_PDISKS=true
- YDB_ENABLE_COLUMN_TABLES=true
6 changes: 6 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ jobs:
run: |
poetry install

- name: Run docker compose
uses: hoverkraft-tech/compose-action@v2.0.1
with:
compose-file: "./.github/docker/docker-compose.yml"
up-flags: "--wait"

- name: Run tests
run: |
poetry run pytest tests
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Bump ydb depencency to 3.21.6

## 0.1.12 ##
* Ability to get view names

Expand Down
1,021 changes: 356 additions & 665 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@ repository = "https://github.com/ydb-platform/ydb-python-dbapi/"

[tool.poetry.dependencies]
python = "^3.8"
ydb = "^3.18.16"
ydb = "^3.21.6"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.5.0"
ruff = "^0.6.9"
mypy = "^1.11.2"
poethepoet = "0.28.0"
types-protobuf = "^5.28.0.20240924"
testcontainers = "^3.7.1"
pytest = "^8.3.3"
pytest-asyncio = "^0.24.0"
sqlalchemy = "^2.0.36"
Expand Down
126 changes: 22 additions & 104 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,118 +3,22 @@
from asyncio import AbstractEventLoop
from collections.abc import AsyncGenerator
from collections.abc import Generator
from concurrent.futures import TimeoutError
from typing import Any
from typing import Callable

import pytest
import ydb
from testcontainers.core.generic import DbContainer
from testcontainers.core.generic import wait_container_is_ready
from testcontainers.core.utils import setup_logger
from typing_extensions import Self

logger = setup_logger(__name__)


class YDBContainer(DbContainer):
def __init__(
self,
name: str | None = None,
port: str = "2136",
image: str = "ydbplatform/local-ydb:trunk",
**kwargs: Any,
) -> None:
docker_client_kw: dict[str, Any] = kwargs.pop("docker_client_kw", {})
docker_client_kw["timeout"] = docker_client_kw.get("timeout") or 300
super().__init__(
image=image,
hostname="localhost",
docker_client_kw=docker_client_kw,
**kwargs,
)
self.port_to_expose = port
self._name = name
self._database_name = "local"

def start(self) -> Self:
self._maybe_stop_old_container()
super().start()
return self

def get_connection_url(self, driver: str = "ydb") -> str:
host = self.get_container_host_ip()
port = self.get_exposed_port(self.port_to_expose)
return f"yql+{driver}://{host}:{port}/local"

def get_connection_string(self) -> str:
host = self.get_container_host_ip()
port = self.get_exposed_port(self.port_to_expose)
return f"grpc://{host}:{port}/?database=/local"

def get_ydb_database_name(self) -> str:
return self._database_name

def get_ydb_host(self) -> str:
return self.get_container_host_ip()

def get_ydb_port(self) -> str:
return self.get_exposed_port(self.port_to_expose)

@wait_container_is_ready(ydb.ConnectionError, TimeoutError)
def _connect(self) -> None:
with ydb.Driver(
connection_string=self.get_connection_string()
) as driver:
driver.wait(fail_fast=True)
try:
driver.scheme_client.describe_path("/local/.sys_health/test")
except ydb.SchemeError as e:
msg = "Database is not ready"
raise ydb.ConnectionError(msg) from e

def _configure(self) -> None:
self.with_bind_ports(self.port_to_expose, self.port_to_expose)
if self._name:
self.with_name(self._name)
self.with_env("YDB_USE_IN_MEMORY_PDISKS", "true")
self.with_env("YDB_DEFAULT_LOG_LEVEL", "DEBUG")
self.with_env("GRPC_PORT", self.port_to_expose)
self.with_env("GRPC_TLS_PORT", self.port_to_expose)

def _maybe_stop_old_container(self) -> None:
if not self._name:
return
docker_client = self.get_docker_client()
running_container = docker_client.client.api.containers(
filters={"name": self._name}
)
if running_container:
logger.info("Stop existing container")
docker_client.client.api.remove_container(
running_container[0], force=True, v=True
)


@pytest.fixture(scope="session")
def ydb_container(
unused_tcp_port_factory: Callable[[], int],
) -> Generator[YDBContainer, None, None]:
with YDBContainer(port=str(unused_tcp_port_factory())) as ydb_container:
yield ydb_container


@pytest.fixture(scope="session")
def connection_string(ydb_container: YDBContainer) -> str:
return ydb_container.get_connection_string()
@pytest.fixture
def connection_string() -> str:
return "grpc://localhost:2136/?database=/local"


@pytest.fixture(scope="session")
def connection_kwargs(ydb_container: YDBContainer) -> dict:
@pytest.fixture
def connection_kwargs() -> dict:
return {
"host": ydb_container.get_ydb_host(),
"port": ydb_container.get_ydb_port(),
"database": ydb_container.get_ydb_database_name(),
"host": "localhost",
"port": "2136",
"database": "/local",
}


Expand Down Expand Up @@ -176,6 +80,13 @@ async def session_pool(

yield session_pool

for name in ["table", "table1", "table2"]:
await session_pool.execute_with_retries(
f"""
DROP TABLE {name};
"""
)


@pytest.fixture
def session_pool_sync(
Expand Down Expand Up @@ -207,3 +118,10 @@ def session_pool_sync(
)

yield session_pool

for name in ["table", "table1", "table2"]:
session_pool.execute_with_retries(
f"""
DROP TABLE {name};
"""
)
81 changes: 81 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,14 @@ def _test_error_with_interactive_tx(
with pytest.raises(dbapi.Error):
maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)"))

maybe_await(
cur.execute_scheme(
"""
DROP TABLE IF EXISTS test;
"""
)
)

maybe_await(cur.close())
maybe_await(connection.rollback())

Expand Down Expand Up @@ -272,6 +280,68 @@ def _test_get_view_names(
assert len(res) == 1
assert res[0] == "test_view"

maybe_await(
cur.execute_scheme(
"""
DROP VIEW test_view;
"""
)
)

maybe_await(cur.close())

def _test_get_table_names(
self,
connection: dbapi.Connection,
) -> None:
cur = connection.cursor()

row_table_name = "test_table_names_row"
column_table_name = "test_table_names_column"

maybe_await(
cur.execute_scheme(
f"""
DROP TABLE if exists {row_table_name};
DROP TABLE if exists {column_table_name};
"""
)
)

res = maybe_await(connection.get_table_names())

assert len(res) == 0

maybe_await(
cur.execute_scheme(
f"""
CREATE TABLE {row_table_name} (
id Utf8 NOT NULL,
PRIMARY KEY(id)
);
CREATE TABLE {column_table_name} (
id Utf8 NOT NULL,
PRIMARY KEY(id)
) WITH (STORE = COLUMN);
"""
)
)

res = maybe_await(connection.get_table_names())

assert len(res) == 2
assert row_table_name in res
assert column_table_name in res

maybe_await(
cur.execute_scheme(
f"""
DROP TABLE {row_table_name};
DROP TABLE {column_table_name};
"""
)
)

maybe_await(cur.close())


Expand Down Expand Up @@ -329,6 +399,11 @@ def test_get_view_names(
) -> None:
self._test_get_view_names(connection)

def test_get_table_names(
self, connection: dbapi.Connection
) -> None:
self._test_get_table_names(connection)


class TestAsyncConnection(BaseDBApiTestSuit):
@pytest_asyncio.fixture
Expand Down Expand Up @@ -404,3 +479,9 @@ async def test_get_view_names(
self, connection: dbapi.AsyncConnection
) -> None:
await greenlet_spawn(self._test_get_view_names, connection)

@pytest.mark.asyncio
async def test_get_table_names(
self, connection: dbapi.AsyncConnection
) -> None:
await greenlet_spawn(self._test_get_table_names, connection)
25 changes: 15 additions & 10 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,18 @@ def check_exists(self, table_path: str) -> bool:
@handle_ydb_errors
def get_table_names(self) -> list[str]:
abs_dir_path = posixpath.join(self.database, self.table_path_prefix)
names = self._get_entity_names(abs_dir_path, ydb.SchemeEntryType.TABLE)
names = self._get_entity_names(
abs_dir_path,
[ydb.SchemeEntryType.TABLE, ydb.SchemeEntryType.COLUMN_TABLE],
)
return [posixpath.relpath(path, abs_dir_path) for path in names]

@handle_ydb_errors
def get_view_names(self) -> list[str]:
abs_dir_path = posixpath.join(self.database, self.table_path_prefix)
names = self._get_entity_names(abs_dir_path, ydb.SchemeEntryType.VIEW)
names = self._get_entity_names(
abs_dir_path, [ydb.SchemeEntryType.VIEW]
)
return [posixpath.relpath(path, abs_dir_path) for path in names]

def _check_path_exists(self, table_path: str) -> bool:
Expand All @@ -302,7 +307,7 @@ def callee() -> None:
return True

def _get_entity_names(
self, abs_dir_path: str, etype: ydb.SchemeEntryType
self, abs_dir_path: str, etypes: list[ydb.SchemeEntryType]
) -> list[str]:
settings = self._get_request_settings()

Expand All @@ -316,10 +321,10 @@ def callee() -> ydb.Directory:
result = []
for child in directory.children:
child_abs_path = posixpath.join(abs_dir_path, child.name)
if child.type == etype:
if child.type in etypes:
result.append(child_abs_path)
elif child.is_directory() and not child.name.startswith("."):
result.extend(self._get_entity_names(child_abs_path, etype))
result.extend(self._get_entity_names(child_abs_path, etypes))
return result

@handle_ydb_errors
Expand Down Expand Up @@ -462,7 +467,7 @@ async def get_table_names(self) -> list[str]:
abs_dir_path = posixpath.join(self.database, self.table_path_prefix)
names = await self._get_entity_names(
abs_dir_path,
ydb.SchemeEntryType.TABLE,
[ydb.SchemeEntryType.TABLE, ydb.SchemeEntryType.COLUMN_TABLE],
)
return [posixpath.relpath(path, abs_dir_path) for path in names]

Expand All @@ -471,7 +476,7 @@ async def get_view_names(self) -> list[str]:
abs_dir_path = posixpath.join(self.database, self.table_path_prefix)
names = await self._get_entity_names(
abs_dir_path,
ydb.SchemeEntryType.VIEW,
[ydb.SchemeEntryType.VIEW],
)
return [posixpath.relpath(path, abs_dir_path) for path in names]

Expand All @@ -492,7 +497,7 @@ async def callee() -> None:
return True

async def _get_entity_names(
self, abs_dir_path: str, etype: ydb.SchemeEntryType
self, abs_dir_path: str, etypes: list[ydb.SchemeEntryType]
) -> list[str]:
settings = self._get_request_settings()

Expand All @@ -506,11 +511,11 @@ async def callee() -> ydb.Directory:
result = []
for child in directory.children:
child_abs_path = posixpath.join(abs_dir_path, child.name)
if child.type == etype:
if child.type in etypes:
result.append(child_abs_path)
elif child.is_directory() and not child.name.startswith("."):
result.extend(
await self._get_entity_names(child_abs_path, etype)
await self._get_entity_names(child_abs_path, etypes)
)
return result

Expand Down