query: remove use of pipe for communication
skshetry committed Sep 5, 2024
1 parent 576b69a commit ed7f42b
Showing 5 changed files with 60 additions and 124 deletions.
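This commit replaces the pipe that the query subprocess used for reporting its resulting dataset with a lookup keyed on the job id: the parent process now creates the job up front, passes its id to the child through the DATACHAIN_JOB_ID environment variable, and afterwards selects the newest dataset version recorded under that job. A minimal sketch of that discovery step, mirroring _get_dataset_versions_by_job_id() and the max() over created_at in the diff below (the helper name and the RuntimeError are illustrative; the real code raises QueryScriptDatasetNotFound):

def latest_version_for_job(catalog, job_id: str):
    # list_datasets_versions() yields (dataset, version, job) triples;
    # keep only versions recorded under the job we launched.
    candidates = [
        (dataset, version)
        for dataset, version, job in catalog.list_datasets_versions()
        if job and str(job.id) == job_id
    ]
    if not candidates:
        raise RuntimeError("no dataset version was recorded for this job")
    # A script may save several versions; the newest one is the result.
    return max(candidates, key=lambda item: item[1].created_at)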
114 changes: 44 additions & 70 deletions src/datachain/catalog/catalog.py
@@ -1416,7 +1416,8 @@ def list_datasets_versions(

for d in datasets:
yield from (
(d, v, jobs.get(v.job_id) if v.job_id else None) for v in d.versions
(d, v, jobs.get(str(v.job_id)) if v.job_id else None)
for v in d.versions
)

def ls_dataset_rows(
@@ -1864,14 +1865,22 @@ def query(
C.size > 1000
)
"""

feature_file = tempfile.NamedTemporaryFile( # noqa: SIM115
dir=os.getcwd(), suffix=".py", delete=False
)
_, feature_module = os.path.split(feature_file.name)

if not job_id:
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
job_id = self.metastore.create_job(
name="",
query=query_script,
params=params,
python_version=python_version,
)

try:
lines, proc, response_text = self.run_query(
lines, proc = self.run_query(
python_executable or sys.executable,
query_script,
envs,
@@ -1908,19 +1917,38 @@
output=output,
)

def _get_dataset_versions_by_job_id():
for dr, dv, job in self.list_datasets_versions():
if job and str(job.id) == job_id:
yield dr, dv

try:
result = json.loads(response_text)
except ValueError:
result = None

dataset: Optional[DatasetRecord] = None
version: Optional[int] = None
if save:
dataset, version = self.save_result(
query_script, result, output, version, job_id
dr, dv = max(
_get_dataset_versions_by_job_id(), key=lambda x: x[1].created_at
)
except ValueError as e:
if not save:
return QueryResult(dataset=None, version=None, output=output)

raise QueryScriptDatasetNotFound(
"No dataset found after running Query script",
output=output,
) from e

return QueryResult(dataset=dataset, version=version, output=output)
dr = self.update_dataset(
dr,
script_output=output,
query_script=query_script,
)
self.update_dataset_version_with_warehouse_info(
dr,
dv.version,
script_output=output,
query_script=query_script,
job_id=job_id,
is_job_result=True,
)
return QueryResult(dataset=dr, version=dv.version, output=output)

def run_query(
self,
@@ -1934,7 +1962,7 @@ def run_query(
params: Optional[dict[str, str]],
save: bool,
job_id: Optional[str],
) -> tuple[list[str], subprocess.Popen, str]:
) -> tuple[list[str], subprocess.Popen]:
try:
feature_code, query_script_compiled = self.compile_query_script(
query_script, feature_module[:-3]
@@ -1947,19 +1975,6 @@
raise QueryScriptCompileError(
f"Query script failed to compile, reason: {exc}"
) from exc
r, w = os.pipe()
if os.name == "nt":
import msvcrt

os.set_inheritable(w, True)

startupinfo = subprocess.STARTUPINFO() # type: ignore[attr-defined]
handle = msvcrt.get_osfhandle(w) # type: ignore[attr-defined]
startupinfo.lpAttributeList["handle_list"].append(handle)
kwargs: dict[str, Any] = {"startupinfo": startupinfo}
else:
handle = w
kwargs = {"pass_fds": [w]}
envs = dict(envs or os.environ)
if feature_code:
envs["DATACHAIN_FEATURE_CLASS_SOURCE"] = json.dumps(
@@ -1971,7 +1986,6 @@
"PYTHONPATH": os.getcwd(), # For local imports
"DATACHAIN_QUERY_SAVE": "1" if save else "",
"PYTHONUNBUFFERED": "1",
"DATACHAIN_OUTPUT_FD": str(handle),
"DATACHAIN_JOB_ID": job_id or "",
},
)
@@ -1982,52 +1996,12 @@
stderr=subprocess.STDOUT if capture_output else None,
bufsize=1,
text=False,
**kwargs,
) as proc:
os.close(w)

out = proc.stdout
_lines: list[str] = []
ctx = print_and_capture(out, output_hook) if out else nullcontext(_lines)

with ctx as lines, open(r) as f:
response_text = ""
while proc.poll() is None:
response_text += f.readline()
time.sleep(0.1)
response_text += f.readline()
return lines, proc, response_text

def save_result(self, query_script, exec_result, output, version, job_id):
if not exec_result:
raise QueryScriptDatasetNotFound(
"No dataset found after running Query script",
output=output,
)
name, version = exec_result
# finding returning dataset
try:
dataset = self.get_dataset(name)
dataset.get_version(version)
except (DatasetNotFoundError, ValueError) as e:
raise QueryScriptDatasetNotFound(
"No dataset found after running Query script",
output=output,
) from e
dataset = self.update_dataset(
dataset,
script_output=output,
query_script=query_script,
)
self.update_dataset_version_with_warehouse_info(
dataset,
version,
script_output=output,
query_script=query_script,
job_id=job_id,
is_job_result=True,
)
return dataset, version
with ctx as lines:
return lines, proc

def cp(
self,
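With the pipe gone, run_query no longer allocates a descriptor, marks it inheritable, or builds platform-specific Popen kwargs (startupinfo on Windows, pass_fds elsewhere); the job id in the child's environment is the only remaining channel. A condensed sketch of the launch under that assumption (the wrapper function is illustrative; the environment variable names follow the diff):

import os
import subprocess
import sys

def launch_query(query_file: str, job_id: str) -> int:
    envs = dict(os.environ)
    envs.update(
        {
            "PYTHONPATH": os.getcwd(),    # for local imports in the script
            "DATACHAIN_QUERY_SAVE": "1",  # ask the wrapper to save the result
            "PYTHONUNBUFFERED": "1",
            "DATACHAIN_JOB_ID": job_id,   # replaces DATACHAIN_OUTPUT_FD
        }
    )
    with subprocess.Popen([sys.executable, query_file], env=envs) as proc:
        proc.wait()
    return proc.returncode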
7 changes: 4 additions & 3 deletions src/datachain/job.py
@@ -1,7 +1,8 @@
import json
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, TypeVar
from typing import Any, Optional, TypeVar, Union

J = TypeVar("J", bound="Job")

@@ -25,7 +26,7 @@ class Job:
@classmethod
def parse(
cls: type[J],
id: str,
id: Union[str, uuid.UUID],
name: str,
status: int,
created_at: datetime,
@@ -40,7 +41,7 @@ def parse(
metrics: str,
) -> "Job":
return cls(
id,
str(id),
name,
status,
created_at,
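The Job.parse change pairs with the jobs.get(str(v.job_id)) fix in catalog.py: a metastore may return ids as uuid.UUID objects, and a UUID never compares equal to its own string form, so lookups in a dict keyed by strings would silently miss. A self-contained illustration (not datachain code):

import uuid

job_id = uuid.uuid4()
jobs = {str(job_id): "job-record"}

assert jobs.get(job_id) is None               # UUID key: lookup misses
assert jobs.get(str(job_id)) == "job-record"  # str key: lookup hits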
23 changes: 0 additions & 23 deletions src/datachain/query/dataset.py
@@ -1,6 +1,5 @@
import contextlib
import inspect
import json
import logging
import os
import random
@@ -1710,19 +1709,6 @@ def save(
return self.__class__(name=name, version=version, catalog=self.catalog)


def _get_output_fd_for_write() -> Union[str, int]:
handle = os.getenv("DATACHAIN_OUTPUT_FD")
if not handle:
return os.devnull

if os.name != "nt":
return int(handle)

import msvcrt

return msvcrt.open_osfhandle(int(handle), os.O_WRONLY) # type: ignore[attr-defined]


def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
"""
Wrapper function that wraps the last statement of user query script.
@@ -1742,13 +1728,4 @@ def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
if save and (is_session_temp_dataset or not dataset_query.attached):
name = catalog.generate_query_dataset_name()
dataset_query = dataset_query.save(name)

dataset: Optional[tuple[str, int]] = None
if dataset_query.attached:
assert dataset_query.name, "Dataset name should be provided"
assert dataset_query.version, "Dataset version should be provided"
dataset = dataset_query.name, dataset_query.version

with open(_get_output_fd_for_write(), mode="w") as f:
json.dump(dataset, f)
return dataset_query
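For contrast, the child side of the deleted protocol worked roughly as follows: resolve the inherited descriptor from DATACHAIN_OUTPUT_FD and stream the resulting (name, version) pair back to the parent as JSON. A condensed sketch of the removed behavior (the msvcrt handle conversion for Windows is elided):

import json
import os

def report_result(name: str, version: int) -> None:
    # Condensed from the removed _get_output_fd_for_write(): fall back
    # to os.devnull when no descriptor was inherited.
    handle = os.getenv("DATACHAIN_OUTPUT_FD")
    target = int(handle) if handle else os.devnull
    with open(target, mode="w") as f:
        json.dump((name, version), f)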
29 changes: 6 additions & 23 deletions tests/func/test_catalog.py
@@ -1,7 +1,5 @@
import io
import json
import os
from contextlib import suppress
from pathlib import Path
from textwrap import dedent
from urllib.parse import urlparse
@@ -45,20 +43,6 @@ def pre_created_ds_name():
return "pre_created_dataset"


@pytest.fixture
def mock_os_pipe(mocker):
r, w = os.pipe()
mocker.patch("os.pipe", return_value=(r, w))

try:
yield (r, w)
finally:
with suppress(OSError):
os.close(r)
with suppress(OSError):
os.close(w)


@pytest.fixture
def mock_popen(mocker):
m = mocker.patch(
@@ -72,21 +56,20 @@ def mock_popen(mocker):

@pytest.fixture
def mock_popen_dataset_created(
mock_popen, cloud_test_catalog, mock_os_pipe, listed_bucket
mocker, monkeypatch, mock_popen, cloud_test_catalog, listed_bucket
):
# create dataset which would be created in subprocess
ds_name = cloud_test_catalog.catalog.generate_query_dataset_name()
ds_version = 1
job_id = cloud_test_catalog.catalog.metastore.create_job(name="", query="")
mocker.patch.object(
cloud_test_catalog.catalog.metastore, "create_job", return_value=job_id
)
monkeypatch.setenv("DATACHAIN_JOB_ID", str(job_id))
cloud_test_catalog.catalog.create_dataset_from_sources(
ds_name,
[f"{cloud_test_catalog.src_uri}/dogs/*"],
recursive=True,
)

_, w = mock_os_pipe
with open(w, mode="w", closefd=False) as f:
f.write(json.dumps((ds_name, ds_version)))

mock_popen.configure_mock(stdout=io.StringIO("user log 1\nuser log 2"))
yield mock_popen

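The fixtures now pin a known job id instead of pushing a fake (name, version) tuple through a mocked pipe: the job is created directly, create_job is patched to return the same id, and DATACHAIN_JOB_ID is set so a dataset saved in-process is attributed to that job. A reduced sketch of the pattern, assuming pytest with pytest-mock and a hypothetical catalog fixture:

import pytest

@pytest.fixture
def pinned_job_id(mocker, monkeypatch, catalog):
    job_id = catalog.metastore.create_job(name="", query="")
    # Any create_job call inside the code under test returns the same id,
    # so datasets saved during the test land under this job.
    mocker.patch.object(catalog.metastore, "create_job", return_value=job_id)
    monkeypatch.setenv("DATACHAIN_JOB_ID", str(job_id))
    return job_id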
11 changes: 6 additions & 5 deletions tests/func/test_query.py
@@ -186,15 +186,16 @@ def test_query(
query_script = setup_catalog(query_script, catalog_info_filepath)

result = catalog.query(query_script, save=save)
if not save:
assert result.dataset is None
return

if save_dataset:
assert result.dataset.name == save_dataset
assert catalog.get_dataset(save_dataset)
else:
elif save:
assert result.dataset.name.startswith(QUERY_DATASET_PREFIX)
else:
assert result.dataset is None
assert result.version is None
return

assert result.version == 1
assert result.dataset.versions_values == [1]
assert result.dataset.query_script == query_script
