42 changes: 30 additions & 12 deletions .github/workflows/tests-studio.yml
@@ -75,18 +75,7 @@ jobs:
path: './backend/datachain'
fetch-depth: 0

- name: Install FFmpeg on Windows
if: runner.os == 'Windows'
run: choco install ffmpeg

- name: Install FFmpeg on macOS
if: runner.os == 'macOS'
run: |
brew install ffmpeg
echo 'DYLD_FALLBACK_LIBRARY_PATH=/opt/homebrew/lib' >> "$GITHUB_ENV"

- name: Install FFmpeg on Ubuntu
if: runner.os == 'Linux'
- name: Install FFmpeg
run: |
sudo apt update
sudo apt install -y ffmpeg
@@ -108,6 +97,35 @@ jobs:
- name: Install dependencies
run: uv pip install --system ./backend/datachain_server[tests] ./backend/datachain[tests]

- name: Initialize datachain venv
env:
PYTHON_VERSION: ${{ matrix.pyv }}
DATACHAIN_VENV_DIR: /tmp/local/datachain_venv/python${{ matrix.pyv }}
run: |
virtualenv -p "$(which python"${PYTHON_VERSION}")" "${DATACHAIN_VENV_DIR}"
Copilot AI (Nov 6, 2025):
The nested quoting in this command substitution is hard to read: python"${PYTHON_VERSION}" sits inside "$(...)", relying on bash re-opening a quoting context within the substitution. A clearer equivalent is "$(which python${PYTHON_VERSION})".

Suggested change:
-    virtualenv -p "$(which python"${PYTHON_VERSION}")" "${DATACHAIN_VENV_DIR}"
+    virtualenv -p "$(which python${PYTHON_VERSION})" "${DATACHAIN_VENV_DIR}"

pip_cache_dir="${DATACHAIN_VENV_DIR}/.cache/pip"
pip_wheel_dir="${pip_cache_dir}/wheels"
pip_bin="${DATACHAIN_VENV_DIR}/bin/pip"
mkdir -p "$pip_cache_dir"
mkdir -p "$pip_wheel_dir"

uv_cache_dir="${DATACHAIN_VENV_DIR}/.cache/uv"
mkdir -p "$uv_cache_dir"

$pip_bin install -U pip wheel setuptools \
--cache-dir="$pip_cache_dir"

$pip_bin wheel ./backend/datachain_server \
--wheel-dir="$pip_wheel_dir" \
--cache-dir="$pip_cache_dir"

uv venv --python "$PYTHON_VERSION" "${DATACHAIN_VENV_DIR}/default"
uv pip install -r ./backend/requirements-worker-venv.txt \
--find-links="$pip_wheel_dir" \
--cache-dir="${DATACHAIN_VENV_DIR}/.cache/uv" \
-p "${DATACHAIN_VENV_DIR}/default/bin/python"

- name: Run tests
# Generate `.test_durations` file with `pytest --store-durations --durations-path ../.github/.test_durations ...`
run: >
84 changes: 47 additions & 37 deletions tests/conftest.py
@@ -20,6 +20,7 @@
from datachain.catalog.loader import get_metastore, get_warehouse
from datachain.cli.utils import CommaSeparatedArgs
from datachain.config import Config, ConfigLevel
from datachain.data_storage import AbstractMetastore, JobQueryType, JobStatus
from datachain.data_storage.sqlite import (
SQLiteDatabaseEngine,
SQLiteMetastore,
@@ -129,6 +130,15 @@ def clean_environment(
monkeypatch_session.delenv(DataChainDir.ENV_VAR_DATACHAIN_ROOT, raising=False)


def _create_job(metastore: AbstractMetastore) -> str:
return metastore.create_job(
"my-job",
'import datachain as dc; dc.read_values(num=[1, 2, 3]).save("nums")',
query_type=JobQueryType.PYTHON,
status=JobStatus.RUNNING,
)


@pytest.fixture
def sqlite_db():
if os.environ.get("DATACHAIN_METASTORE") or os.environ.get("DATACHAIN_WAREHOUSE"):
@@ -162,9 +172,13 @@ def cleanup_sqlite_db(


@pytest.fixture
def metastore():
def metastore(monkeypatch):
if os.environ.get("DATACHAIN_METASTORE"):
_metastore = get_metastore()

job_id = _create_job(_metastore)
monkeypatch.setenv("DATACHAIN_JOB_ID", job_id)

yield _metastore

_metastore.cleanup_for_tests()
@@ -233,14 +247,18 @@ def test_session(catalog):


@pytest.fixture
def metastore_tmpfile(tmp_path):
def metastore_tmpfile(monkeypatch, tmp_path):
if os.environ.get("DATACHAIN_METASTORE"):
_metastore = get_metastore()

job_id = _create_job(_metastore)
monkeypatch.setenv("DATACHAIN_JOB_ID", job_id)

yield _metastore

_metastore.cleanup_for_tests()
else:
_metastore = SQLiteMetastore(db_file=tmp_path / "test.db")
_metastore = SQLiteMetastore(db_file=str(tmp_path / "test.db"))
Copilot AI (Nov 6, 2025):
Converting Path to string is unnecessary here. The SQLiteMetastore constructor likely accepts Path objects directly. If the change from tmp_path / "test.db" to str(tmp_path / "test.db") is intentional to fix an issue, it would be helpful to understand why, but this appears inconsistent with other code that uses paths directly.

Suggested change:
-        _metastore = SQLiteMetastore(db_file=str(tmp_path / "test.db"))
+        _metastore = SQLiteMetastore(db_file=tmp_path / "test.db")
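As an aside, a hypothetical sketch (not part of this PR) of how a constructor could accept either form without callers adding str(...) conversions: os.fspath() normalizes both str and pathlib.Path.

import os
from pathlib import Path


def normalize_db_file(db_file) -> str:
    # Hypothetical helper, assuming the underlying engine needs a plain string;
    # os.fspath() accepts str and any os.PathLike object.
    return os.fspath(db_file)


print(normalize_db_file(Path("/tmp") / "test.db"))  # /tmp/test.db
print(normalize_db_file("/tmp/test.db"))            # /tmp/test.db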
yield _metastore

cleanup_sqlite_db(_metastore.db.clone(), _metastore.default_table_names)
@@ -261,7 +279,7 @@ def warehouse_tmpfile(tmp_path, metastore_tmpfile):
finally:
_warehouse.cleanup_for_tests()
else:
_warehouse = SQLiteWarehouse(db_file=tmp_path / "test.db")
_warehouse = SQLiteWarehouse(db_file=str(tmp_path / "test.db"))
Copilot AI (Nov 6, 2025):
Same note as above: converting Path to string is unnecessary; the SQLiteWarehouse constructor likely accepts Path objects directly.

Suggested change:
-        _warehouse = SQLiteWarehouse(db_file=str(tmp_path / "test.db"))
+        _warehouse = SQLiteWarehouse(db_file=tmp_path / "test.db")
yield _warehouse
try:
check_temp_tables_cleaned_up(_warehouse)
@@ -470,26 +488,21 @@ def cloud_server(request, tmp_upath_factory, cloud_type, version_aware, tree):

@pytest.fixture
def datachain_job_id(test_session, monkeypatch):
job_id = test_session.catalog.metastore.create_job(
"my-job",
'import datachain as dc; dc.read_values(num=[1, 2, 3].save("nums")',
)
monkeypatch.setenv("DATACHAIN_JOB_ID", job_id)
return job_id
yield _create_job(test_session.catalog.metastore)


@pytest.fixture
def cloud_server_credentials(cloud_server, monkeypatch):
@pytest.fixture(scope="session")
def cloud_server_credentials(cloud_server):
if cloud_server.kind == "s3":
cfg = cloud_server.src.fs.client_kwargs
try:
monkeypatch.delenv("AWS_PROFILE")
os.environ.pop("AWS_PROFILE")
except KeyError:
pass
monkeypatch.setenv("AWS_ACCESS_KEY_ID", cfg.get("aws_access_key_id"))
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", cfg.get("aws_secret_access_key"))
monkeypatch.setenv("AWS_SESSION_TOKEN", cfg.get("aws_session_token"))
monkeypatch.setenv("AWS_DEFAULT_REGION", cfg.get("region_name"))
os.environ["AWS_ACCESS_KEY_ID"] = cfg.get("aws_access_key_id")
os.environ["AWS_SECRET_ACCESS_KEY"] = cfg.get("aws_secret_access_key")
os.environ["AWS_SESSION_TOKEN"] = cfg.get("aws_session_token")
os.environ["AWS_DEFAULT_REGION"] = cfg.get("region_name")


def get_cloud_test_catalog(cloud_server, tmp_path, metastore, warehouse):
@@ -854,34 +867,22 @@ def pseudo_random_ds(test_session):


@pytest.fixture()
def run_datachain_worker(datachain_job_id):
def run_datachain_worker(monkeypatch):
if not os.environ.get("DATACHAIN_DISTRIBUTED"):
pytest.skip("Distributed tests are disabled")

job_id = os.environ.get("DATACHAIN_JOB_ID")
assert job_id, "DATACHAIN_JOB_ID environment variable is required for this test"

monkeypatch.delenv("DATACHAIN_DISTRIBUTED_DISABLED", raising=False)

# This worker can take several tasks in parallel, as it's very handy
# for testing, where we don't want [yet] to constrain the number of
# available workers.
workers = []
worker_cmd = [
"celery",
"-A",
"datachain_worker.tasks",
"worker",
"--loglevel=INFO",
"--hostname=tests-datachain-worker-main",
"--pool=solo",
"--concurrency=1",
"--max-tasks-per-child=1",
"--prefetch-multiplier=1",
"-Q",
f"datachain-worker-main-{job_id}",
]
print(f"Starting worker with command: {' '.join(worker_cmd)}")
workers.append(subprocess.Popen(worker_cmd, shell=False)) # noqa: S603
workers: list[subprocess.Popen] = []
queues: list[str] = []
for i in range(2):
queue_name = f"udf-{uuid.uuid4()}"
worker_cmd = [
"celery",
"-A",
Expand All @@ -894,10 +895,16 @@ def run_datachain_worker(datachain_job_id):
"--max-tasks-per-child=1",
"--prefetch-multiplier=1",
"-Q",
"udf_runner_queue",
queue_name,
]
queues.append(queue_name)
print(f"Starting worker with command: {' '.join(worker_cmd)}")
workers.append(subprocess.Popen(worker_cmd, shell=False)) # noqa: S603
worker_proc = subprocess.Popen( # noqa: S603
worker_cmd,
env=os.environ,
shell=False,
)
workers.append(worker_proc)
try:
from datachain_worker.utils.celery import celery_app

Expand All @@ -912,6 +919,9 @@ def run_datachain_worker(datachain_job_id):
if attempts == 10:
raise RuntimeError("Celery worker(s) did not start in time")

monkeypatch.setenv("DATACHAIN_STEP_ID", "1")
monkeypatch.setenv("UDF_RUNNER_QUEUE_NAME_LIST", ",".join(queues))

yield workers
finally:
for worker in workers:
4 changes: 4 additions & 0 deletions tests/func/test_checkpoints.py
@@ -11,6 +11,10 @@ def mock_is_script_run(monkeypatch):
monkeypatch.setattr("datachain.query.session.is_script_run", lambda: True)


@pytest.mark.skipif(
"os.environ.get('DATACHAIN_DISTRIBUTED')",
reason="Checkpoints test skipped in distributed mode",
)
Copilot AI (Nov 6, 2025), on lines +14 to +17:
The @pytest.mark.skipif condition uses a string expression but os is not imported in this file. This will cause a NameError at test collection time. Import os at the top of the file or use 'DATACHAIN_DISTRIBUTED' in os.environ with a proper import.
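For reference, here is a minimal hypothetical sketch (not part of this PR) of the explicit form the reviewer suggests. Note that pytest evaluates string skipif conditions in a namespace that already provides os and sys, so the string form above may also work as written; the boolean form below simply avoids the ambiguity.

import os

import pytest


# Hypothetical rewrite of the marker with an explicit import and a boolean condition.
@pytest.mark.skipif(
    bool(os.environ.get("DATACHAIN_DISTRIBUTED")),
    reason="Checkpoints test skipped in distributed mode",
)
def test_checkpoints_parallel_example():
    # Placeholder body; the real test lives in tests/func/test_checkpoints.py.
    ...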
def test_checkpoints_parallel(test_session_tmpfile, monkeypatch):
def mapper_fail(num) -> int:
raise Exception("Error")
20 changes: 5 additions & 15 deletions tests/func/test_udf.py
@@ -553,14 +553,8 @@ def name_len_error(_name):
.map(name_len_error, params=["file.path"], output={"name_len": int})
)

if os.environ.get("DATACHAIN_DISTRIBUTED"):
# in distributed mode we expect DataChainError with the error message
with pytest.raises(DataChainError, match="Test Error!"):
chain.show()
else:
# while in local mode we expect RuntimeError with the error message
with pytest.raises(RuntimeError, match="UDF Execution Failed!"):
chain.show()
with pytest.raises(RuntimeError, match="UDF Execution Failed!"):
chain.show()


@pytest.mark.parametrize(
@@ -736,12 +730,8 @@ def name_len_interrupt(_name):
.settings(parallel=True)
.map(name_len_interrupt, params=["file.path"], output={"name_len": int})
)
if os.environ.get("DATACHAIN_DISTRIBUTED"):
with pytest.raises(KeyboardInterrupt):
chain.show()
else:
with pytest.raises(RuntimeError, match="UDF Execution Failed!"):
chain.show()
with pytest.raises(RuntimeError, match="UDF Execution Failed!"):
chain.show()
captured = capfd.readouterr()
assert "semaphore" not in captured.err

@@ -874,7 +864,7 @@ def name_len_interrupt(_name):
.settings(parallel=parallel, workers=workers)
.map(name_len_interrupt, params=["file.path"], output={"name_len": int})
)
with pytest.raises(KeyboardInterrupt):
with pytest.raises(Exception, match="UDF task failed with exit code"):
chain.show()
captured = capfd.readouterr()
assert "semaphore" not in captured.err
16 changes: 16 additions & 0 deletions tests/unit/lib/test_checkpoints.py
@@ -21,6 +21,10 @@ def nums_dataset(test_session):
return dc.read_values(num=[1, 2, 3], session=test_session).save("nums")


@pytest.mark.skipif(
"os.environ.get('DATACHAIN_DISTRIBUTED')",
reason="Checkpoints test skipped in distributed mode",
)
Copilot AI (Nov 6, 2025), on lines +24 to +27:
Same note as in tests/func/test_checkpoints.py — the skipif condition uses a string expression but os is not imported in this file.
@pytest.mark.parametrize("reset_checkpoints", [True, False])
@pytest.mark.parametrize("with_delta", [True, False])
@pytest.mark.parametrize("use_datachain_job_id_env", [True, False])
@@ -84,6 +88,10 @@ def test_checkpoints(
assert len(list(catalog.metastore.list_checkpoints(second_job_id))) == 3


@pytest.mark.skipif(
"os.environ.get('DATACHAIN_DISTRIBUTED')",
reason="Checkpoints test skipped in distributed mode",
)
Copilot AI (Nov 6, 2025), on lines +91 to +94:
Same note as above — the skipif condition uses a string expression but os is not imported in this file.
@pytest.mark.parametrize("reset_checkpoints", [True, False])
def test_checkpoints_modified_chains(
test_session, monkeypatch, nums_dataset, reset_checkpoints
@@ -115,6 +123,10 @@ def test_checkpoints_modified_chains(
assert len(list(catalog.metastore.list_checkpoints(second_job_id))) == 3


@pytest.mark.skipif(
"os.environ.get('DATACHAIN_DISTRIBUTED')",
reason="Checkpoints test skipped in distributed mode",
)
Copilot AI (Nov 6, 2025), on lines +126 to +129:
Same note as above — the skipif condition uses a string expression but os is not imported in this file.
@pytest.mark.parametrize("reset_checkpoints", [True, False])
def test_checkpoints_multiple_runs(
test_session, monkeypatch, nums_dataset, reset_checkpoints
@@ -180,6 +192,10 @@ def test_checkpoints_multiple_runs(
assert len(list(catalog.metastore.list_checkpoints(fourth_job_id))) == 3


@pytest.mark.skipif(
"os.environ.get('DATACHAIN_DISTRIBUTED')",
reason="Checkpoints test skipped in distributed mode",
)
Copilot AI (Nov 6, 2025), on lines +195 to +198:
Same note as above — the skipif condition uses a string expression but os is not imported in this file.
def test_checkpoints_check_valid_chain_is_returned(
test_session,
monkeypatch,
4 changes: 3 additions & 1 deletion tests/unit/test_catalog_loader.py
@@ -80,7 +80,9 @@ def test_get_warehouse_in_memory():
warehouse.close()


def test_get_distributed_class():
def test_get_distributed_class(monkeypatch):
monkeypatch.delenv("DATACHAIN_DISTRIBUTED_DISABLED", raising=False)

with patch.dict(os.environ, {"DATACHAIN_DISTRIBUTED": ""}):
assert get_udf_distributor_class() is None
