query: remove use of pipe for communication #393

Merged 1 commit on Sep 5, 2024
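
In short: the query subprocess previously reported the resulting dataset name and version back to the parent process over an OS pipe (the DATACHAIN_OUTPUT_FD file descriptor). This PR removes that pipe. The parent now creates the job up front, passes its id to the subprocess through DATACHAIN_JOB_ID, and once the script exits it discovers the result by looking for the newest dataset version attached to that job id. Below is a minimal sketch of the parent-side lookup, with simplified names rather than the exact catalog API, based on the logic in this diff:

    def find_result_for_job(catalog, job_id: str):
        """Return the newest dataset version produced by the given job, if any."""
        candidates = [
            (dataset, version)
            for dataset, version, job in catalog.list_datasets_versions()
            if job and str(job.id) == job_id
        ]
        if not candidates:
            # The query script did not save a dataset; in the diff below,
            # catalog.query() raises QueryScriptDatasetNotFound in this case.
            return None
        # Pick the most recently created version, mirroring the max(..., key=...)
        # call in the updated query() method.
        return max(candidates, key=lambda pair: pair[1].created_at)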
114 changes: 44 additions & 70 deletions src/datachain/catalog/catalog.py
@@ -1416,7 +1416,8 @@ def list_datasets_versions(

for d in datasets:
yield from (
(d, v, jobs.get(v.job_id) if v.job_id else None) for v in d.versions
(d, v, jobs.get(str(v.job_id)) if v.job_id else None)
for v in d.versions
)

def ls_dataset_rows(
@@ -1864,14 +1865,22 @@ def query(
C.size > 1000
)
"""

feature_file = tempfile.NamedTemporaryFile( # noqa: SIM115
dir=os.getcwd(), suffix=".py", delete=False
)
_, feature_module = os.path.split(feature_file.name)

if not job_id:
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
job_id = self.metastore.create_job(
name="",
query=query_script,
params=params,
python_version=python_version,
)

try:
lines, proc, response_text = self.run_query(
lines, proc = self.run_query(
python_executable or sys.executable,
query_script,
envs,
@@ -1908,19 +1917,38 @@ def query(
output=output,
)

def _get_dataset_versions_by_job_id():
for dr, dv, job in self.list_datasets_versions():
if job and str(job.id) == job_id:
yield dr, dv

try:
result = json.loads(response_text)
except ValueError:
result = None

dataset: Optional[DatasetRecord] = None
version: Optional[int] = None
if save:
dataset, version = self.save_result(
query_script, result, output, version, job_id
dr, dv = max(
_get_dataset_versions_by_job_id(), key=lambda x: x[1].created_at
)
except ValueError as e:
if not save:
return QueryResult(dataset=None, version=None, output=output)

raise QueryScriptDatasetNotFound(
"No dataset found after running Query script",
output=output,
) from e

return QueryResult(dataset=dataset, version=version, output=output)
dr = self.update_dataset(
dr,
script_output=output,
query_script=query_script,
)
self.update_dataset_version_with_warehouse_info(
dr,
dv.version,
script_output=output,
query_script=query_script,
job_id=job_id,
is_job_result=True,
)
return QueryResult(dataset=dr, version=dv.version, output=output)

def run_query(
self,
@@ -1934,7 +1962,7 @@ def run_query(
params: Optional[dict[str, str]],
save: bool,
job_id: Optional[str],
) -> tuple[list[str], subprocess.Popen, str]:
) -> tuple[list[str], subprocess.Popen]:
try:
feature_code, query_script_compiled = self.compile_query_script(
query_script, feature_module[:-3]
@@ -1947,19 +1975,6 @@ def run_query(
raise QueryScriptCompileError(
f"Query script failed to compile, reason: {exc}"
) from exc
r, w = os.pipe()
if os.name == "nt":
import msvcrt

os.set_inheritable(w, True)

startupinfo = subprocess.STARTUPINFO() # type: ignore[attr-defined]
handle = msvcrt.get_osfhandle(w) # type: ignore[attr-defined]
startupinfo.lpAttributeList["handle_list"].append(handle)
kwargs: dict[str, Any] = {"startupinfo": startupinfo}
else:
handle = w
kwargs = {"pass_fds": [w]}
envs = dict(envs or os.environ)
if feature_code:
envs["DATACHAIN_FEATURE_CLASS_SOURCE"] = json.dumps(
@@ -1971,7 +1986,6 @@ def run_query(
"PYTHONPATH": os.getcwd(), # For local imports
"DATACHAIN_QUERY_SAVE": "1" if save else "",
"PYTHONUNBUFFERED": "1",
"DATACHAIN_OUTPUT_FD": str(handle),
"DATACHAIN_JOB_ID": job_id or "",
},
)
@@ -1982,52 +1996,12 @@ def run_query(
stderr=subprocess.STDOUT if capture_output else None,
bufsize=1,
text=False,
**kwargs,
) as proc:
os.close(w)

out = proc.stdout
_lines: list[str] = []
ctx = print_and_capture(out, output_hook) if out else nullcontext(_lines)

with ctx as lines, open(r) as f:
response_text = ""
while proc.poll() is None:
response_text += f.readline()
time.sleep(0.1)
response_text += f.readline()
return lines, proc, response_text

def save_result(self, query_script, exec_result, output, version, job_id):
if not exec_result:
raise QueryScriptDatasetNotFound(
"No dataset found after running Query script",
output=output,
)
name, version = exec_result
# finding returning dataset
try:
dataset = self.get_dataset(name)
dataset.get_version(version)
except (DatasetNotFoundError, ValueError) as e:
raise QueryScriptDatasetNotFound(
"No dataset found after running Query script",
output=output,
) from e
dataset = self.update_dataset(
dataset,
script_output=output,
query_script=query_script,
)
self.update_dataset_version_with_warehouse_info(
dataset,
version,
script_output=output,
query_script=query_script,
job_id=job_id,
is_job_result=True,
)
return dataset, version
with ctx as lines:
return lines, proc

def cp(
self,
7 changes: 4 additions & 3 deletions src/datachain/job.py
@@ -1,7 +1,8 @@
import json
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, TypeVar
from typing import Any, Optional, TypeVar, Union

J = TypeVar("J", bound="Job")

@@ -25,7 +26,7 @@ class Job:
@classmethod
def parse(
cls: type[J],
id: str,
id: Union[str, uuid.UUID],
name: str,
status: int,
created_at: datetime,
@@ -40,7 +41,7 @@ def parse(
metrics: str,
) -> "Job":
return cls(
id,
str(id),
name,
status,
created_at,
23 changes: 0 additions & 23 deletions src/datachain/query/dataset.py
@@ -1,6 +1,5 @@
import contextlib
import inspect
import json
import logging
import os
import random
@@ -1710,19 +1709,6 @@ def save(
return self.__class__(name=name, version=version, catalog=self.catalog)


def _get_output_fd_for_write() -> Union[str, int]:
handle = os.getenv("DATACHAIN_OUTPUT_FD")
if not handle:
return os.devnull

if os.name != "nt":
return int(handle)

import msvcrt

return msvcrt.open_osfhandle(int(handle), os.O_WRONLY) # type: ignore[attr-defined]


def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
"""
Wrapper function that wraps the last statement of user query script.
@@ -1742,13 +1728,4 @@ def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
if save and (is_session_temp_dataset or not dataset_query.attached):
name = catalog.generate_query_dataset_name()
dataset_query = dataset_query.save(name)

dataset: Optional[tuple[str, int]] = None
if dataset_query.attached:
assert dataset_query.name, "Dataset name should be provided"
assert dataset_query.version, "Dataset version should be provided"
dataset = dataset_query.name, dataset_query.version

with open(_get_output_fd_for_write(), mode="w") as f:
json.dump(dataset, f)
return dataset_query
29 changes: 6 additions & 23 deletions tests/func/test_catalog.py
@@ -1,7 +1,5 @@
import io
import json
import os
from contextlib import suppress
from pathlib import Path
from textwrap import dedent
from urllib.parse import urlparse
@@ -45,20 +43,6 @@ def pre_created_ds_name():
return "pre_created_dataset"


@pytest.fixture
def mock_os_pipe(mocker):
r, w = os.pipe()
mocker.patch("os.pipe", return_value=(r, w))

try:
yield (r, w)
finally:
with suppress(OSError):
os.close(r)
with suppress(OSError):
os.close(w)


@pytest.fixture
def mock_popen(mocker):
m = mocker.patch(
@@ -72,21 +56,20 @@ def mock_popen(mocker):

@pytest.fixture
def mock_popen_dataset_created(
mock_popen, cloud_test_catalog, mock_os_pipe, listed_bucket
mocker, monkeypatch, mock_popen, cloud_test_catalog, listed_bucket
):
# create dataset which would be created in subprocess
ds_name = cloud_test_catalog.catalog.generate_query_dataset_name()
ds_version = 1
job_id = cloud_test_catalog.catalog.metastore.create_job(name="", query="")
mocker.patch.object(
cloud_test_catalog.catalog.metastore, "create_job", return_value=job_id
)
monkeypatch.setenv("DATACHAIN_JOB_ID", str(job_id))
cloud_test_catalog.catalog.create_dataset_from_sources(
ds_name,
[f"{cloud_test_catalog.src_uri}/dogs/*"],
recursive=True,
)

_, w = mock_os_pipe
with open(w, mode="w", closefd=False) as f:
f.write(json.dumps((ds_name, ds_version)))

mock_popen.configure_mock(stdout=io.StringIO("user log 1\nuser log 2"))
yield mock_popen

11 changes: 6 additions & 5 deletions tests/func/test_query.py
@@ -186,15 +186,16 @@ def test_query(
query_script = setup_catalog(query_script, catalog_info_filepath)

result = catalog.query(query_script, save=save)
if not save:
assert result.dataset is None
return

if save_dataset:
assert result.dataset.name == save_dataset
assert catalog.get_dataset(save_dataset)
else:
elif save:
assert result.dataset.name.startswith(QUERY_DATASET_PREFIX)
else:
assert result.dataset is None
assert result.version is None
return

assert result.version == 1
assert result.dataset.versions_values == [1]
assert result.dataset.query_script == query_script