Skip to content

Commit

Permalink
allows for built-in ast unparse if present
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Nov 25, 2024
1 parent e6b5b02 commit bfd015e
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 16 deletions.
13 changes: 10 additions & 3 deletions dlt/common/reflection/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import ast
import inspect
import astunparse
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Callable

try:
import astunparse

ast_unparse: Callable[[ast.AST], str] = astunparse.unparse
except ImportError:
ast_unparse = ast.unparse # type: ignore[attr-defined, unused-ignore]

from dlt.common.typing import AnyFun

Expand All @@ -25,7 +32,7 @@ def get_literal_defaults(node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) ->
literal_defaults: Dict[str, str] = {}
for arg, default in zip(reversed(args), reversed(defaults)):
if default:
literal_defaults[str(arg.arg)] = astunparse.unparse(default).strip()
literal_defaults[str(arg.arg)] = ast_unparse(default).strip()

return literal_defaults

Expand Down Expand Up @@ -99,7 +106,7 @@ def rewrite_python_script(
script_lines.append(source_script_lines[last_line][last_offset : node.col_offset])

# replace node value
script_lines.append(astunparse.unparse(t_value).strip())
script_lines.append(ast_unparse(t_value).strip())
last_line = node.end_lineno - 1
last_offset = node.end_col_offset

Expand Down
4 changes: 0 additions & 4 deletions dlt/destinations/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@
from contextlib import contextmanager

from dlt import version

from dlt.common.json import json

from dlt.common.normalizers.naming.naming import NamingConvention
from dlt.common.exceptions import MissingDependencyException

from dlt.common.destination import AnyDestination
from dlt.common.destination.reference import (
SupportsReadableRelation,
Expand Down
3 changes: 1 addition & 2 deletions dlt/destinations/impl/filesystem/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
TPipelineStateDoc,
load_package as current_load_package,
)
from dlt.destinations.sql_client import DBApiCursor, WithSqlClient, SqlClientBase
from dlt.destinations.sql_client import WithSqlClient, SqlClientBase
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import (
FollowupJobRequest,
Expand All @@ -63,7 +63,6 @@
from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration
from dlt.destinations import path_utils
from dlt.destinations.fs_client import FSClientBase
from dlt.destinations.dataset import ReadableDBAPIDataset
from dlt.destinations.utils import verify_schema_merge_disposition

INIT_FILE_NAME = "init"
Expand Down
8 changes: 4 additions & 4 deletions dlt/reflection/script_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import astunparse
from ast import NodeVisitor
from typing import Any, Dict, List
from dlt.common.reflection.utils import find_outer_func_def

from dlt.common.reflection.utils import find_outer_func_def, ast_unparse

import dlt.reflection.names as n

Expand Down Expand Up @@ -68,9 +68,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
for deco in node.decorator_list:
# decorators can be function calls, attributes or names
if isinstance(deco, (ast.Name, ast.Attribute)):
alias_name = astunparse.unparse(deco).strip()
alias_name = ast_unparse(deco).strip()
elif isinstance(deco, ast.Call):
alias_name = astunparse.unparse(deco.func).strip()
alias_name = ast_unparse(deco.func).strip()
else:
raise ValueError(
self.source_segment(deco), type(deco), "Unknown decorator form"
Expand All @@ -87,7 +87,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
def visit_Call(self, node: ast.Call) -> Any:
if self._curr_pass == 2:
# check if this is a call to any of known functions
alias_name = astunparse.unparse(node.func).strip()
alias_name = ast_unparse(node.func).strip()
fn = self.func_aliases.get(alias_name)
if not fn:
# try a fallback to "run" function that may be called on pipeline or source
Expand Down
5 changes: 2 additions & 3 deletions dlt/sources/sql_database/arrow_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

from dlt.common.configuration import with_config
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.libs.pyarrow import (
row_tuples_to_arrow as _row_tuples_to_arrow,
)


@with_config
Expand All @@ -20,6 +17,8 @@ def row_tuples_to_arrow(
is always the case if run within the pipeline. This will generate arrow schema compatible with the destination.
Otherwise generic capabilities are used
"""
from dlt.common.libs.pyarrow import row_tuples_to_arrow as _row_tuples_to_arrow

return _row_tuples_to_arrow(
rows, caps or DestinationCapabilitiesContext.generic_capabilities(), columns, tz
)

0 comments on commit bfd015e

Please sign in to comment.