cli: remove preview from datachain query command (#392)
skshetry authored Sep 4, 2024
1 parent cbd20f2 commit 2952eb1
Showing 7 changed files with 25 additions and 108 deletions.
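In short: the query subprocess used to serialize an ExecutionResult (preview rows plus the saved dataset's name and version) back to the parent over a dedicated file descriptor; after this commit it writes only the (name, version) tuple, or null when nothing was saved. A minimal sketch of the new handshake, with the fd plumbing and error handling trimmed for illustration (send_result and parse_result are hypothetical names, not functions in the codebase):

    import json
    from typing import Optional

    def send_result(fd: int, dataset: Optional[tuple[str, int]]) -> None:
        # Child side: write the saved dataset, e.g. ["my-ds", 1], or null.
        with open(fd, mode="w") as f:
            json.dump(dataset, f)

    def parse_result(response_text: str) -> Optional[tuple[str, int]]:
        # Parent side: tolerate empty or malformed output from the child.
        try:
            result = json.loads(response_text)
        except ValueError:
            return None
        return (result[0], result[1]) if result else None
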
36 changes: 6 additions & 30 deletions src/datachain/catalog/catalog.py
@@ -156,7 +156,6 @@ class QueryResult(NamedTuple):
     dataset: Optional[DatasetRecord]
     version: Optional[int]
     output: str
-    preview: Optional[list[dict]]


 class DatasetRowsFetcher(NodesThreadPool):
@@ -1861,9 +1860,6 @@ def query(
         envs: Optional[Mapping[str, str]] = None,
         python_executable: Optional[str] = None,
         save: bool = False,
-        preview_limit: int = 10,
-        preview_offset: int = 0,
-        preview_columns: Optional[list[str]] = None,
         capture_output: bool = True,
         output_hook: Callable[[str], None] = noop,
         params: Optional[dict[str, str]] = None,
@@ -1891,7 +1887,6 @@ def query(
             C.size > 1000
         )
         """
-        from datachain.query.dataset import ExecutionResult

         feature_file = tempfile.NamedTemporaryFile(  # noqa: SIM115
             dir=os.getcwd(), suffix=".py", delete=False
@@ -1908,9 +1903,6 @@ def query(
             feature_module,
             output_hook,
             params,
-            preview_columns,
-            preview_limit,
-            preview_offset,
             save,
             job_id,
         )
@@ -1940,24 +1932,18 @@ def query(
         )

         try:
-            response = json.loads(response_text)
+            result = json.loads(response_text)
         except ValueError:
-            response = {}
-        exec_result = ExecutionResult(**response)
+            result = None

         dataset: Optional[DatasetRecord] = None
         version: Optional[int] = None
         if save:
             dataset, version = self.save_result(
-                query_script, exec_result, output, version, job_id
+                query_script, result, output, version, job_id
             )

-        return QueryResult(
-            dataset=dataset,
-            version=version,
-            output=output,
-            preview=exec_result.preview,
-        )
+        return QueryResult(dataset=dataset, version=version, output=output)

     def run_query(
         self,
@@ -1969,9 +1955,6 @@ def run_query(
         feature_module: str,
         output_hook: Callable[[str], None],
         params: Optional[dict[str, str]],
-        preview_columns: Optional[list[str]],
-        preview_limit: int,
-        preview_offset: int,
         save: bool,
         job_id: Optional[str],
     ) -> tuple[list[str], subprocess.Popen, str]:
@@ -2009,13 +1992,6 @@ def run_query(
             {
                 "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
                 "PYTHONPATH": os.getcwd(),  # For local imports
-                "DATACHAIN_QUERY_PREVIEW_ARGS": json.dumps(
-                    {
-                        "limit": preview_limit,
-                        "offset": preview_offset,
-                        "columns": preview_columns,
-                    }
-                ),
                 "DATACHAIN_QUERY_SAVE": "1" if save else "",
                 "PYTHONUNBUFFERED": "1",
                 "DATACHAIN_OUTPUT_FD": str(handle),
@@ -2046,12 +2022,12 @@ def run_query(
         return lines, proc, response_text

     def save_result(self, query_script, exec_result, output, version, job_id):
-        if not exec_result.dataset:
+        if not exec_result:
             raise QueryScriptDatasetNotFound(
                 "No dataset found after running Query script",
                 output=output,
             )
-        name, version = exec_result.dataset
+        name, version = exec_result
         # finding returning dataset
         try:
             dataset = self.get_dataset(name)
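
Net effect on the Python API, read off the diff rather than any documented contract: Catalog.query loses its preview_* keyword arguments, and the QueryResult it returns now carries only dataset, version, and output. A caller that previously read result.preview would inspect the saved dataset instead, roughly:

    # Sketch; assumes an existing Catalog instance and a query script string.
    result = catalog.query(script_content, save=True, capture_output=False)
    if result.dataset is not None:
        print(result.dataset.name, result.version)
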
17 changes: 1 addition & 16 deletions src/datachain/cli.py
@@ -484,7 +484,6 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
             "N defaults to the CPU count."
         ),
     )
-    add_show_args(query_parser)
     query_parser.add_argument(
         "-p",
         "--param",
@@ -811,14 +810,9 @@ def query(
     catalog: "Catalog",
     script: str,
     parallel: Optional[int] = None,
-    limit: int = 10,
-    offset: int = 0,
-    columns: Optional[list[str]] = None,
-    no_collapse: bool = False,
     params: Optional[dict[str, str]] = None,
 ) -> None:
     from datachain.data_storage import JobQueryType, JobStatus
-    from datachain.utils import show_records

     with open(script, encoding="utf-8") as f:
         script_content = f.read()
@@ -839,12 +833,9 @@
     )

     try:
-        result = catalog.query(
+        catalog.query(
             script_content,
             python_executable=python_executable,
-            preview_limit=limit,
-            preview_offset=offset,
-            preview_columns=columns,
             capture_output=False,
             params=params,
             job_id=job_id,
@@ -861,8 +852,6 @@
             raise
     catalog.metastore.set_job_status(job_id, JobStatus.COMPLETE)

-    show_records(result.preview, collapse_columns=not no_collapse)
-

 def clear_cache(catalog: "Catalog"):
     catalog.cache.clear()
@@ -1037,10 +1026,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR0915
             catalog,
             args.script,
             parallel=args.parallel,
-            limit=args.limit,
-            offset=args.offset,
-            columns=args.columns,
-            no_collapse=args.no_collapse,
             params=args.param,
         )
     elif args.command == "apply-udf":
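
On the CLI surface this drops the preview flags that add_show_args attached to the query subcommand (judging from the removed args.limit, args.offset, args.columns, and args.no_collapse, presumably --limit, --offset, --columns, and --no-collapse). A script is now expected to print its own output, so an invocation along the lines of

    datachain query script.py --columns file.path,emd.value

becomes simply

    datachain query script.py

with the column selection moved into the script itself (see the feature_class.py change below).
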
57 changes: 8 additions & 49 deletions src/datachain/query/dataset.py
@@ -1,5 +1,4 @@
 import contextlib
-import datetime
 import inspect
 import json
 import logging
@@ -1724,53 +1723,6 @@ def _get_output_fd_for_write() -> Union[str, int]:
         return msvcrt.open_osfhandle(int(handle), os.O_WRONLY)  # type: ignore[attr-defined]


-@attrs.define
-class ExecutionResult:
-    preview: list[dict] = attrs.field(factory=list)
-    dataset: Optional[tuple[str, int]] = None
-
-
-def _send_result(dataset_query: DatasetQuery) -> None:
-    class JSONSerialize(json.JSONEncoder):
-        def default(self, obj):
-            if isinstance(obj, (datetime.datetime, datetime.date)):
-                return obj.isoformat()
-            if isinstance(obj, bytes):
-                return list(obj[:1024])
-            return super().default(obj)
-
-    try:
-        preview_args: dict[str, Any] = json.loads(
-            os.getenv("DATACHAIN_QUERY_PREVIEW_ARGS", "")
-        )
-    except ValueError:
-        preview_args = {}
-
-    columns = preview_args.get("columns") or []
-
-    if type(dataset_query) is DatasetQuery:
-        preview_query = dataset_query.select(*columns)
-    else:
-        preview_query = dataset_query.select(*columns, _sys=False)
-
-    preview_query = preview_query.limit(preview_args.get("limit", 10)).offset(
-        preview_args.get("offset", 0)
-    )
-
-    dataset: Optional[tuple[str, int]] = None
-    if dataset_query.attached:
-        assert dataset_query.name, "Dataset name should be provided"
-        assert dataset_query.version, "Dataset version should be provided"
-        dataset = dataset_query.name, dataset_query.version
-
-    preview = preview_query.to_db_records()
-    result = ExecutionResult(preview, dataset)
-    data = attrs.asdict(result)
-
-    with open(_get_output_fd_for_write(), mode="w") as f:
-        json.dump(data, f, cls=JSONSerialize)
-
-
 def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
     """
     Wrapper function that wraps the last statement of user query script.
@@ -1791,5 +1743,12 @@ def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
         name = catalog.generate_query_dataset_name()
         dataset_query = dataset_query.save(name)

-    _send_result(dataset_query)
+    dataset: Optional[tuple[str, int]] = None
+    if dataset_query.attached:
+        assert dataset_query.name, "Dataset name should be provided"
+        assert dataset_query.version, "Dataset version should be provided"
+        dataset = dataset_query.name, dataset_query.version
+
+    with open(_get_output_fd_for_write(), mode="w") as f:
+        json.dump(dataset, f)
     return dataset_query
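
query_wrapper still auto-saves the final DatasetQuery when saving is requested, but the payload written to DATACHAIN_OUTPUT_FD shrinks from an object with preview and dataset keys to the bare dataset tuple. Assuming a script that saved version 1 of a dataset named my-ds, the parent now reads

    ["my-ds", 1]

instead of something like

    {"preview": [...], "dataset": ["my-ds", 1]}

and an unsaved query yields null, which Catalog.query turns into result = None.
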
2 changes: 1 addition & 1 deletion tests/func/test_catalog.py
@@ -75,7 +75,7 @@ def mock_popen_dataset_created(

     _, w = mock_os_pipe
     with open(w, mode="w", closefd=False) as f:
-        f.write(json.dumps({"dataset": (ds_name, ds_version)}))
+        f.write(json.dumps((ds_name, ds_version)))

     mock_popen.configure_mock(stdout=io.StringIO("user log 1\nuser log 2"))
     yield mock_popen
13 changes: 5 additions & 8 deletions tests/func/test_query.py
@@ -107,13 +107,10 @@ def test_query_cli(cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath, capsys):
     filepath = tmp_path / "query_script.py"
     filepath.write_text(query_script)

-    query(catalog, str(filepath), columns=["name"])
-    captured = capsys.readouterr()
-
-    header, *rows = captured.out.splitlines()
-    assert header.strip() == "name"
-    name_rows = {row.split()[1] for row in rows}
-    assert name_rows == {"cat1", "cat2", "description", "dog1", "dog2", "dog3", "dog4"}
+    query(catalog, str(filepath))
+    out, err = capsys.readouterr()
+    assert not out
+    assert not err

     dataset = catalog.get_dataset("my-ds")
     assert dataset
@@ -152,7 +149,7 @@ def test_query_cli_no_dataset_returned(
         QueryScriptRunError,
         match="Last line in a script was not an instance of DataChain",
     ):
-        query(catalog, str(filepath), columns=["name"])
+        query(catalog, str(filepath))

     latest_job = get_latest_job(catalog.metastore)
     assert latest_job
2 changes: 1 addition & 1 deletion tests/scripts/feature_class.py
@@ -14,5 +14,5 @@ class Embedding(BaseModel):
     .limit(5)
     .map(emd=lambda file: Embedding(value=512), output=Embedding)
 )
-
 ds.select("file.path", "emd.value").show(limit=5, flatten=True)
+ds.save(ds_name)
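
With the CLI preview gone, this test script produces its own output via .show(limit=5, flatten=True) and saves the dataset explicitly. The e2e expectation below changes accordingly: show()-style headers (file.path, emd.value) and the [Limited by 5 rows] footer replace the old flattened to_db_records-style columns (file__path, emd__value).
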
6 changes: 3 additions & 3 deletions tests/test_query_e2e.py
@@ -89,17 +89,17 @@
         "datachain",
         "query",
         os.path.join(tests_dir, "scripts", "feature_class.py"),
-        "--columns",
-        "file.path,emd.value",
     ),
     "expected_rows": dedent(
         """
-        file__path emd__value
+                          file.path  emd.value
         0     dogs-and-cats/cat.1.jpg      512.0
         1    dogs-and-cats/cat.10.jpg      512.0
         2   dogs-and-cats/cat.100.jpg      512.0
         3  dogs-and-cats/cat.1000.jpg      512.0
         4  dogs-and-cats/cat.1001.jpg      512.0
+
+        [Limited by 5 rows]
         """
     ),
 },
