Skip to content

Commit

Permalink
do not require last statement to be an expression or an instance of D…
Browse files Browse the repository at this point in the history
…atasetQuery (#395)
  • Loading branch information
skshetry authored Sep 5, 2024
1 parent a8d3640 commit a9f77ba
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 71 deletions.
8 changes: 0 additions & 8 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,8 +658,6 @@ def attach_query_wrapper(self, code_ast):
),
]
code_ast.body[-1:] = new_expressions
else:
raise Exception("Last line in a script was not an expression")
return code_ast

def compile_query_script(
Expand Down Expand Up @@ -1905,12 +1903,6 @@ def query(
return_code=proc.returncode,
output=output,
)
if proc.returncode == QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE:
raise QueryScriptRunError(
"Last line in a script was not an instance of DataChain",
return_code=proc.returncode,
output=output,
)
raise QueryScriptRunError(
f"Query script exited with error code {proc.returncode}",
return_code=proc.returncode,
Expand Down
10 changes: 3 additions & 7 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@
from tqdm import tqdm

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
from datachain.catalog import (
QUERY_SCRIPT_CANCELED_EXIT_CODE,
QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE,
get_catalog,
)
from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
from datachain.data_storage.schema import (
PARTITION_COLUMN_ID,
partition_col_names,
Expand Down Expand Up @@ -1709,14 +1705,14 @@ def save(
return self.__class__(name=name, version=version, catalog=self.catalog)


def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
def query_wrapper(dataset_query: Any) -> Any:
"""
Wrapper function that wraps the last statement of user query script.
Last statement MUST be instance of DatasetQuery, otherwise script exits with
error code 10
"""
if not isinstance(dataset_query, DatasetQuery):
sys.exit(QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE)
return dataset_query

catalog = dataset_query.catalog
save = bool(os.getenv("DATACHAIN_QUERY_SAVE"))
Expand Down
36 changes: 0 additions & 36 deletions tests/func/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,42 +971,6 @@ def test_query_subprocess_wrong_return_code(mock_popen, cloud_test_catalog):
assert str(exc_info.value).startswith("Query script exited with error code 1")


def test_query_last_statement_not_expression(mock_popen, cloud_test_catalog):
mock_popen.configure_mock(returncode=10)
catalog = cloud_test_catalog.catalog
src_uri = cloud_test_catalog.src_uri

query_script = f"""
from datachain.query import DatasetQuery, C
ds = DatasetQuery('{src_uri}')
"""

with pytest.raises(QueryScriptCompileError) as exc_info:
catalog.query(query_script)
assert str(exc_info.value).startswith(
"Query script failed to compile, "
"reason: Last line in a script was not an expression"
)


def test_query_last_statement_not_ds_query_instance(mock_popen, cloud_test_catalog):
mock_popen.configure_mock(returncode=10)
catalog = cloud_test_catalog.catalog
src_uri = cloud_test_catalog.src_uri

query_script = f"""
from datachain.query import DatasetQuery, C
ds = DatasetQuery('{src_uri}')
5
"""

with pytest.raises(QueryScriptRunError) as exc_info:
catalog.query(query_script)
assert str(exc_info.value).startswith(
"Last line in a script was not an instance of DataChain"
)


def test_query_dataset_not_returned(mock_popen, cloud_test_catalog):
mock_popen.configure_mock(stdout=io.StringIO("random str"))
catalog = cloud_test_catalog.catalog
Expand Down
31 changes: 11 additions & 20 deletions tests/func/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from datachain.catalog import QUERY_DATASET_PREFIX
from datachain.cli import query
from datachain.data_storage import AbstractDBMetastore, JobQueryType, JobStatus
from datachain.error import QueryScriptRunError
from tests.utils import assert_row_names


Expand Down Expand Up @@ -128,37 +127,29 @@ def test_query_cli(cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath,
assert latest_job[5] == ""


def test_query_cli_no_dataset_returned(
def test_query_cli_without_dataset_query_as_a_last_statement(
cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath, capsys
):
catalog = cloud_test_catalog_tmpfile.catalog
src_uri = cloud_test_catalog_tmpfile.src_uri

query_script = """\
query_script = f"""\
from datachain.query import DatasetQuery
DatasetQuery("test", catalog=catalog)
DatasetQuery({src_uri!r}, catalog=catalog).save("temp")
print("test")
"""
query_script = setup_catalog(query_script, catalog_info_filepath)

filepath = tmp_path / "query_script.py"
filepath.write_text(query_script)

with pytest.raises(
QueryScriptRunError,
match="Last line in a script was not an instance of DataChain",
):
query(catalog, str(filepath))

latest_job = get_latest_job(catalog.metastore)
assert latest_job
result = catalog.query(query_script)
assert result.dataset
assert result.dataset.name == "temp"
assert result.version == 1

assert latest_job[1] == os.path.basename(filepath)
assert latest_job[2] == JobStatus.FAILED
assert latest_job[3] == JobQueryType.PYTHON
assert latest_job[4] == "Last line in a script was not an instance of DataChain"
assert latest_job[5].find("datachain.error.QueryScriptRunError")
out, err = capsys.readouterr()
assert "test" in out
assert not err


@pytest.mark.parametrize(
Expand Down

0 comments on commit a9f77ba

Please sign in to comment.