Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes the usage of escaped JSONPath in incremental cursors in sql_database #2077

Merged
merged 3 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion dlt/common/jsonpath.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Union, List, Any
from typing import Iterable, Union, List, Any, Optional, cast
from itertools import chain

from dlt.common.typing import DictStrAny
Expand Down Expand Up @@ -46,3 +46,48 @@ def resolve_paths(paths: TAnyJsonPath, data: DictStrAny) -> List[str]:
paths = compile_paths(paths)
p: JSONPath
return list(chain.from_iterable((str(r.full_path) for r in p.find(data)) for p in paths))


def is_simple_field_path(path: JSONPath) -> bool:
"""Checks if the given path represents a simple single field name.

Example:
>>> is_simple_field_path(compile_path('id'))
True
>>> is_simple_field_path(compile_path('$.id'))
False
"""
return isinstance(path, JSONPathFields) and len(path.fields) == 1 and path.fields[0] != "*"


def extract_simple_field_name(path: Union[str, JSONPath]) -> Optional[str]:
"""
Extracts a simple field name from a JSONPath if it represents a single field access.
Returns None if the path is complex (contains wildcards, array indices, or multiple fields).

Args:
path: A JSONPath object or string

Returns:
Optional[str]: The field name if path represents a simple field access, None otherwise

Example:
>>> extract_simple_field_name('name')
'name'
>>> extract_simple_field_name('"name"')
'name'
>>> extract_simple_field_name('"na$me"') # Escaped characters are preserved
'na$me'
>>> extract_simple_field_name('"na.me"') # Escaped characters are preserved
'na.me'
>>> extract_simple_field_name('$.name') # Returns None
>>> extract_simple_field_name('$.items[*].name') # Returns None
>>> extract_simple_field_name('*') # Returns None
"""
if isinstance(path, str):
path = compile_path(path)

if is_simple_field_path(path):
return cast(str, path.fields[0])

return None
10 changes: 3 additions & 7 deletions dlt/extract/incremental/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dlt.common.json import json
from dlt.common.pendulum import pendulum
from dlt.common.typing import TDataItem
from dlt.common.jsonpath import find_values, JSONPathFields, compile_path
from dlt.common.jsonpath import find_values, compile_path, extract_simple_field_name
from dlt.extract.incremental.exceptions import (
IncrementalCursorInvalidCoercion,
IncrementalCursorPathMissing,
Expand Down Expand Up @@ -75,12 +75,8 @@ def __init__(
# compile jsonpath
self._compiled_cursor_path = compile_path(cursor_path)
# for simple column name we'll fallback to search in dict
if (
isinstance(self._compiled_cursor_path, JSONPathFields)
and len(self._compiled_cursor_path.fields) == 1
and self._compiled_cursor_path.fields[0] != "*"
):
self.cursor_path = self._compiled_cursor_path.fields[0]
if simple_field_name := extract_simple_field_name(self._compiled_cursor_path):
self.cursor_path = simple_field_name
self._compiled_cursor_path = None

def compute_unique_value(
Expand Down
11 changes: 10 additions & 1 deletion dlt/sources/sql_database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dlt.common.exceptions import MissingDependencyException
from dlt.common.schema import TTableSchemaColumns
from dlt.common.typing import TDataItem, TSortOrder
from dlt.common.jsonpath import extract_simple_field_name

from dlt.extract import Incremental

Expand Down Expand Up @@ -60,8 +61,16 @@ def __init__(
self.query_adapter_callback = query_adapter_callback
self.incremental = incremental
if incremental:
column_name = extract_simple_field_name(incremental.cursor_path)

if column_name is None:
raise ValueError(
f"Cursor path '{incremental.cursor_path}' must be a simple column name (e.g."
" 'created_at')"
)

try:
self.cursor_column = table.c[incremental.cursor_path]
self.cursor_column = table.c[column_name]
except KeyError as e:
raise KeyError(
f"Cursor column '{incremental.cursor_path}' does not exist in table"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ Incremental loading uses a cursor column (e.g., timestamp or auto-incrementing I
1. **Set end_value for backfill**: Set `end_value` if you want to backfill data from a certain range.
1. **Order returned rows**: Set `row_order` to `asc` or `desc` to order returned rows.

:::info Special characters in the cursor column name
If your cursor column name contains special characters (e.g., `$`) you need to escape it when passing it to the `incremental` function. For example, if your cursor column is `example_$column`, you should pass it as `"'example_$column'"` or `'"example_$column"'` to the `incremental` function: `incremental("'example_$column'", initial_value=...)`.
:::

#### Examples

1. **Incremental loading with the resource `sql_table`**.
Expand Down
80 changes: 80 additions & 0 deletions tests/load/sources/sql_database/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,86 @@ class MockIncremental:
assert query.compare(expected)


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
def test_cursor_path_field_name_with_a_special_chars(
sql_source_db: SQLAlchemySourceDB, backend: TableBackend
) -> None:
"""Test that a field name with special characters in cursor_path is handled correctly."""
table = sql_source_db.get_table("chat_message")

# Add a mock column with a special character
special_field_name = "id$field"
if special_field_name not in table.c:
table.append_column(sa.Column(special_field_name, sa.String))

class MockIncremental:
cursor_path = "'id$field'"
last_value = None
end_value = None
row_order = None
on_cursor_value_missing = None

# Should not raise any exception
loader = TableLoader(
sql_source_db.engine,
backend,
table,
table_to_columns(table),
incremental=MockIncremental(), # type: ignore[arg-type]
)
assert loader.cursor_column == table.c[special_field_name]


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
def test_cursor_path_multiple_fields(
sql_source_db: SQLAlchemySourceDB, backend: TableBackend
) -> None:
"""Test that a cursor_path with multiple fields raises a ValueError."""
table = sql_source_db.get_table("chat_message")

class MockIncremental:
cursor_path = "created_at,updated_at"
last_value = None
end_value = None
row_order = None
on_cursor_value_missing = None

with pytest.raises(ValueError) as excinfo:
TableLoader(
sql_source_db.engine,
backend,
table,
table_to_columns(table),
incremental=MockIncremental(), # type: ignore[arg-type]
)
assert "must be a simple column name" in str(excinfo.value)


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
def test_cursor_path_complex_expression(
sql_source_db: SQLAlchemySourceDB, backend: TableBackend
) -> None:
"""Test that a complex JSONPath expression in cursor_path raises a ValueError."""
table = sql_source_db.get_table("chat_message")

class MockIncremental:
cursor_path = "$.users[0].id"
last_value = None
end_value = None
row_order = None
on_cursor_value_missing = None

with pytest.raises(ValueError) as excinfo:
TableLoader(
sql_source_db.engine,
backend,
table,
table_to_columns(table),
incremental=MockIncremental(), # type: ignore[arg-type]
)
assert "must be a simple column name" in str(excinfo.value)


def mock_json_column(field: str) -> TDataItem:
""""""
import pyarrow as pa
Expand Down
Loading