From ed7f42b0d632511d99e733561bffd66bca9a1588 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Wed, 4 Sep 2024 20:11:33 +0545 Subject: [PATCH] query: remove use of pipe for communication --- src/datachain/catalog/catalog.py | 114 ++++++++++++------------------- src/datachain/job.py | 7 +- src/datachain/query/dataset.py | 23 ------- tests/func/test_catalog.py | 29 ++------ tests/func/test_query.py | 11 +-- 5 files changed, 60 insertions(+), 124 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 78d6f1460..e6a021e09 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -1416,7 +1416,8 @@ def list_datasets_versions( for d in datasets: yield from ( - (d, v, jobs.get(v.job_id) if v.job_id else None) for v in d.versions + (d, v, jobs.get(str(v.job_id)) if v.job_id else None) + for v in d.versions ) def ls_dataset_rows( @@ -1864,14 +1865,22 @@ def query( C.size > 1000 ) """ - feature_file = tempfile.NamedTemporaryFile( # noqa: SIM115 dir=os.getcwd(), suffix=".py", delete=False ) _, feature_module = os.path.split(feature_file.name) + if not job_id: + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + job_id = self.metastore.create_job( + name="", + query=query_script, + params=params, + python_version=python_version, + ) + try: - lines, proc, response_text = self.run_query( + lines, proc = self.run_query( python_executable or sys.executable, query_script, envs, @@ -1908,19 +1917,38 @@ def query( output=output, ) + def _get_dataset_versions_by_job_id(): + for dr, dv, job in self.list_datasets_versions(): + if job and str(job.id) == job_id: + yield dr, dv + try: - result = json.loads(response_text) - except ValueError: - result = None - - dataset: Optional[DatasetRecord] = None - version: Optional[int] = None - if save: - dataset, version = self.save_result( - query_script, result, output, version, job_id + dr, dv = max( + _get_dataset_versions_by_job_id(), key=lambda x: x[1].created_at ) + except ValueError as e: + if not save: + return QueryResult(dataset=None, version=None, output=output) + + raise QueryScriptDatasetNotFound( + "No dataset found after running Query script", + output=output, + ) from e - return QueryResult(dataset=dataset, version=version, output=output) + dr = self.update_dataset( + dr, + script_output=output, + query_script=query_script, + ) + self.update_dataset_version_with_warehouse_info( + dr, + dv.version, + script_output=output, + query_script=query_script, + job_id=job_id, + is_job_result=True, + ) + return QueryResult(dataset=dr, version=dv.version, output=output) def run_query( self, @@ -1934,7 +1962,7 @@ def run_query( params: Optional[dict[str, str]], save: bool, job_id: Optional[str], - ) -> tuple[list[str], subprocess.Popen, str]: + ) -> tuple[list[str], subprocess.Popen]: try: feature_code, query_script_compiled = self.compile_query_script( query_script, feature_module[:-3] @@ -1947,19 +1975,6 @@ def run_query( raise QueryScriptCompileError( f"Query script failed to compile, reason: {exc}" ) from exc - r, w = os.pipe() - if os.name == "nt": - import msvcrt - - os.set_inheritable(w, True) - - startupinfo = subprocess.STARTUPINFO() # type: ignore[attr-defined] - handle = msvcrt.get_osfhandle(w) # type: ignore[attr-defined] - startupinfo.lpAttributeList["handle_list"].append(handle) - kwargs: dict[str, Any] = {"startupinfo": startupinfo} - else: - handle = w - kwargs = {"pass_fds": 
[w]} envs = dict(envs or os.environ) if feature_code: envs["DATACHAIN_FEATURE_CLASS_SOURCE"] = json.dumps( @@ -1971,7 +1986,6 @@ def run_query( "PYTHONPATH": os.getcwd(), # For local imports "DATACHAIN_QUERY_SAVE": "1" if save else "", "PYTHONUNBUFFERED": "1", - "DATACHAIN_OUTPUT_FD": str(handle), "DATACHAIN_JOB_ID": job_id or "", }, ) @@ -1982,52 +1996,12 @@ def run_query( stderr=subprocess.STDOUT if capture_output else None, bufsize=1, text=False, - **kwargs, ) as proc: - os.close(w) - out = proc.stdout _lines: list[str] = [] ctx = print_and_capture(out, output_hook) if out else nullcontext(_lines) - - with ctx as lines, open(r) as f: - response_text = "" - while proc.poll() is None: - response_text += f.readline() - time.sleep(0.1) - response_text += f.readline() - return lines, proc, response_text - - def save_result(self, query_script, exec_result, output, version, job_id): - if not exec_result: - raise QueryScriptDatasetNotFound( - "No dataset found after running Query script", - output=output, - ) - name, version = exec_result - # finding returning dataset - try: - dataset = self.get_dataset(name) - dataset.get_version(version) - except (DatasetNotFoundError, ValueError) as e: - raise QueryScriptDatasetNotFound( - "No dataset found after running Query script", - output=output, - ) from e - dataset = self.update_dataset( - dataset, - script_output=output, - query_script=query_script, - ) - self.update_dataset_version_with_warehouse_info( - dataset, - version, - script_output=output, - query_script=query_script, - job_id=job_id, - is_job_result=True, - ) - return dataset, version + with ctx as lines: + return lines, proc def cp( self, diff --git a/src/datachain/job.py b/src/datachain/job.py index b5653335a..93b134d8a 100644 --- a/src/datachain/job.py +++ b/src/datachain/job.py @@ -1,7 +1,8 @@ import json +import uuid from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional, TypeVar +from typing import Any, Optional, TypeVar, Union J = TypeVar("J", bound="Job") @@ -25,7 +26,7 @@ class Job: @classmethod def parse( cls: type[J], - id: str, + id: Union[str, uuid.UUID], name: str, status: int, created_at: datetime, @@ -40,7 +41,7 @@ def parse( metrics: str, ) -> "Job": return cls( - id, + str(id), name, status, created_at, diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 5150dcfe7..d0a6f3c0b 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -1,6 +1,5 @@ import contextlib import inspect -import json import logging import os import random @@ -1710,19 +1709,6 @@ def save( return self.__class__(name=name, version=version, catalog=self.catalog) -def _get_output_fd_for_write() -> Union[str, int]: - handle = os.getenv("DATACHAIN_OUTPUT_FD") - if not handle: - return os.devnull - - if os.name != "nt": - return int(handle) - - import msvcrt - - return msvcrt.open_osfhandle(int(handle), os.O_WRONLY) # type: ignore[attr-defined] - - def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery: """ Wrapper function that wraps the last statement of user query script. 
@@ -1742,13 +1728,4 @@ def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery: if save and (is_session_temp_dataset or not dataset_query.attached): name = catalog.generate_query_dataset_name() dataset_query = dataset_query.save(name) - - dataset: Optional[tuple[str, int]] = None - if dataset_query.attached: - assert dataset_query.name, "Dataset name should be provided" - assert dataset_query.version, "Dataset version should be provided" - dataset = dataset_query.name, dataset_query.version - - with open(_get_output_fd_for_write(), mode="w") as f: - json.dump(dataset, f) return dataset_query diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py index 2c93476c1..d3879a36d 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -1,7 +1,5 @@ import io -import json import os -from contextlib import suppress from pathlib import Path from textwrap import dedent from urllib.parse import urlparse @@ -45,20 +43,6 @@ def pre_created_ds_name(): return "pre_created_dataset" -@pytest.fixture -def mock_os_pipe(mocker): - r, w = os.pipe() - mocker.patch("os.pipe", return_value=(r, w)) - - try: - yield (r, w) - finally: - with suppress(OSError): - os.close(r) - with suppress(OSError): - os.close(w) - - @pytest.fixture def mock_popen(mocker): m = mocker.patch( @@ -72,21 +56,20 @@ def mock_popen(mocker): @pytest.fixture def mock_popen_dataset_created( - mock_popen, cloud_test_catalog, mock_os_pipe, listed_bucket + mocker, monkeypatch, mock_popen, cloud_test_catalog, listed_bucket ): # create dataset which would be created in subprocess ds_name = cloud_test_catalog.catalog.generate_query_dataset_name() - ds_version = 1 + job_id = cloud_test_catalog.catalog.metastore.create_job(name="", query="") + mocker.patch.object( + cloud_test_catalog.catalog.metastore, "create_job", return_value=job_id + ) + monkeypatch.setenv("DATACHAIN_JOB_ID", str(job_id)) cloud_test_catalog.catalog.create_dataset_from_sources( ds_name, [f"{cloud_test_catalog.src_uri}/dogs/*"], recursive=True, ) - - _, w = mock_os_pipe - with open(w, mode="w", closefd=False) as f: - f.write(json.dumps((ds_name, ds_version))) - mock_popen.configure_mock(stdout=io.StringIO("user log 1\nuser log 2")) yield mock_popen diff --git a/tests/func/test_query.py b/tests/func/test_query.py index ae4830ca8..20810878d 100644 --- a/tests/func/test_query.py +++ b/tests/func/test_query.py @@ -186,15 +186,16 @@ def test_query( query_script = setup_catalog(query_script, catalog_info_filepath) result = catalog.query(query_script, save=save) - if not save: - assert result.dataset is None - return - if save_dataset: assert result.dataset.name == save_dataset assert catalog.get_dataset(save_dataset) - else: + elif save: assert result.dataset.name.startswith(QUERY_DATASET_PREFIX) + else: + assert result.dataset is None + assert result.version is None + return + assert result.version == 1 assert result.dataset.versions_values == [1] assert result.dataset.query_script == query_script
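
Notes:

The parent now pre-creates a job row via metastore.create_job() when no job_id
is supplied, exports it to the child as DATACHAIN_JOB_ID, and, after the
subprocess exits, recovers the produced dataset by scanning
list_datasets_versions() for versions carrying that job id, taking the newest
by created_at; when nothing matches, QueryScriptDatasetNotFound is raised only
if save was requested. Below is a minimal, self-contained sketch of that
pattern. The names (JOB_ID, REGISTRY, the record layout) and the throwaway
JSON file standing in for the shared metastore are illustrative assumptions,
not DataChain's API:

    import json
    import os
    import subprocess
    import sys
    import tempfile
    import uuid

    # Parent pre-creates a job id and hands it to the child through the
    # environment, mirroring DATACHAIN_JOB_ID in this patch.
    job_id = str(uuid.uuid4())
    registry = os.path.join(tempfile.gettempdir(), f"results-{job_id}.json")

    # The child tags whatever it produces with the inherited job id instead
    # of writing a (name, version) pair back through an inherited pipe fd.
    child_script = (
        "import json, os\n"
        "record = {'job_id': os.environ['JOB_ID'],"
        " 'dataset': 'ds_query_1', 'version': 1}\n"
        "with open(os.environ['REGISTRY'], 'w') as f:\n"
        "    json.dump([record], f)\n"
    )
    env = {**os.environ, "JOB_ID": job_id, "REGISTRY": registry}
    subprocess.run([sys.executable, "-c", child_script], env=env, check=True)

    # After the child exits, the parent discovers the result by job id,
    # analogous to filtering list_datasets_versions() on v.job_id.
    with open(registry) as f:
        matches = [r for r in json.load(f) if r["job_id"] == job_id]
    os.remove(registry)
    try:
        latest = max(matches, key=lambda r: r["version"])  # newest wins
    except ValueError:  # max() over an empty sequence, as in the patch
        raise RuntimeError("No dataset found after running Query script")
    print(latest["dataset"], latest["version"])

The design point is that the metastore, which both processes already share,
becomes the return channel, so no file descriptor has to survive the
fork/exec boundary on either platform.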
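
For contrast, a sketch of the mechanism this patch deletes, shown for the
POSIX branch only; on Windows the removed code additionally had to mark the
handle inheritable with msvcrt.get_osfhandle() and thread it through
STARTUPINFO's handle_list, which is much of why it was worth removing.
OUTPUT_FD here plays the role of DATACHAIN_OUTPUT_FD, and the child payload
is a stand-in for the JSON (name, version) pair that query_wrapper used to
dump:

    import os
    import subprocess
    import sys

    r, w = os.pipe()
    child_script = (
        "import os\n"
        "fd = int(os.environ['OUTPUT_FD'])\n"
        "with open(fd, 'w') as f:\n"
        "    f.write('[\"ds_query_1\", 1]')\n"
    )
    with subprocess.Popen(
        [sys.executable, "-c", child_script],
        env={**os.environ, "OUTPUT_FD": str(w)},
        pass_fds=[w],  # keep the write end open across exec
    ) as proc:
        os.close(w)  # parent drops its copy so EOF arrives once the child exits
        with open(r) as f:
            response_text = f.read()
    print(response_text)  # '["ds_query_1", 1]'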