Skip to content

Commit

Permalink
Fix notebook sources with NotebookLinter.apply (#3693)
Browse files Browse the repository at this point in the history
## Changes
Add notebook fixing to `NotebookLinter`
- Implement `Notebook.apply` 
- Call `Notebook.apply` from `FileLinter.apply` for `Notebook` source
container
- Remove legacy `NotebookMigrator`
- Introduce `PythonFixer` to run apply on an AST tree
- Allow to 

### Linked issues

Progresses #3514
Breaks up #3520

### Functionality

- [x] modified existing command: `databricks labs ucx
migrate-local-code`

### Tests

- [x] manually tested
- [x] modified and added unit tests
- [x] modified and added integration tests
  • Loading branch information
JCZuurmond authored Feb 20, 2025
1 parent 01c4263 commit eb6009a
Show file tree
Hide file tree
Showing 16 changed files with 561 additions and 135 deletions.
38 changes: 38 additions & 0 deletions src/databricks/labs/ucx/github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from enum import Enum
import urllib.parse


DOCS_URL = "https://databrickslabs.github.io/ucx/docs/"
GITHUB_URL = "https://github.com/databrickslabs/ucx"


class IssueType(Enum):
"""The issue type"""

FEATURE = "Feature"
BUG = "Bug"
TASK = "Task"


def construct_new_issue_url(
issue_type: IssueType,
title: str,
body: str,
*,
labels: set[str] | None = None,
) -> str:
"""Construct a new issue URL.
References:
- https://docs.github.com/en/issues/tracking-your-work-with-issues/using-issues/creating-an-issue#creating-an-issue-from-a-url-query
"""
labels = labels or set()
labels.add("needs-triage")
parameters = {
"type": issue_type.value,
"title": title,
"body": body,
"labels": ",".join(sorted(labels)),
}
query = "&".join(f"{key}={urllib.parse.quote_plus(value)}" for key, value in parameters.items())
return f"{GITHUB_URL}/issues/new?{query}"
5 changes: 3 additions & 2 deletions src/databricks/labs/ucx/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from databricks.sdk.useragent import with_extra

from databricks.labs.ucx.github import DOCS_URL
from databricks.labs.ucx.__about__ import __version__
from databricks.labs.ucx.assessment.azure import AzureServicePrincipalInfo
from databricks.labs.ucx.assessment.clusters import ClusterInfo, PolicyInfo
Expand Down Expand Up @@ -218,7 +219,7 @@ def run(
if isinstance(err.__cause__, RequestsConnectionError):
logger.warning(
f"Cannot connect with {self.workspace_client.config.host} see "
f"https://github.com/databrickslabs/ucx#network-connectivity-issues for help: {err}"
f"{DOCS_URL}#reference/common_challenges/#network-connectivity-issues for help: {err}"
)
raise err
return config
Expand Down Expand Up @@ -573,7 +574,7 @@ def _create_database(self):
if "Unable to load AWS credentials from any provider in the chain" in str(err):
msg = (
"The UCX installation is configured to use external metastore. There is issue with the external metastore connectivity.\n"
"Please check the UCX installation instruction https://github.com/databrickslabs/ucx?tab=readme-ov-file#prerequisites"
f"Please check the UCX installation instruction {DOCS_URL}/installation "
"and re-run installation.\n"
"Please Follow the Below Command to uninstall and Install UCX\n"
"UCX Uninstall: databricks labs uninstall ucx.\n"
Expand Down
5 changes: 3 additions & 2 deletions src/databricks/labs/ucx/source_code/known.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from databricks.labs.blueprint.entrypoint import get_logger

from databricks.labs.ucx.github import GITHUB_URL
from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationIndex
from databricks.labs.ucx.source_code.base import Advice, CurrentSessionState
from databricks.labs.ucx.source_code.graph import (
Expand All @@ -24,6 +25,7 @@
from databricks.labs.ucx.source_code.path_lookup import PathLookup

logger = logging.getLogger(__name__)
KNOWN_URL = f"{GITHUB_URL}/blob/main/src/databricks/labs/ucx/source_code/known.json"

"""
Known libraries that are not in known.json
Expand Down Expand Up @@ -282,10 +284,9 @@ class KnownDependency(Dependency):
"""A dependency for known libraries, see :class:KnownList."""

def __init__(self, module_name: str, problems: list[KnownProblem]):
known_url = "https://github.com/databrickslabs/ucx/blob/main/src/databricks/labs/ucx/source_code/known.json"
# Note that Github does not support navigating JSON files, hence the #<module_name> does nothing.
# https://docs.github.com/en/repositories/working-with-files/using-files/navigating-code-on-github
super().__init__(KnownLoader(), Path(f"{known_url}#{module_name}"), inherits_context=False)
super().__init__(KnownLoader(), Path(f"{KNOWN_URL}#{module_name}"), inherits_context=False)
self._module_name = module_name
self.problems = problems

Expand Down
29 changes: 29 additions & 0 deletions src/databricks/labs/ucx/source_code/linters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,35 @@ def lint(self, code: str) -> Iterable[Advice]:
def lint_tree(self, tree: Tree) -> Iterable[Advice]: ...


class PythonFixer(Fixer):
    """Base class for fixers that rewrite Python source code via its AST."""

    def apply(self, code: str) -> str:
        """Apply the changes to Python source code.

        The source code is parsed into an AST tree, and the fixes are applied
        to the tree.
        """
        parsed = MaybeTree.from_source_code(code)
        if parsed.failure:
            # Fixing does not yield parse failures, linting does
            logger.warning(f"Parsing source code resulted in failure `{parsed.failure}`: {code}")
            return code
        assert parsed.tree is not None
        fixed_tree = self.apply_tree(parsed.tree)
        return fixed_tree.node.as_string()

    @abstractmethod
    def apply_tree(self, tree: Tree) -> Tree:
        """Apply the fixes to the AST tree.

        For Python, the fixes are applied to a Tree so that we
        - Can chain multiple fixers without transpiling back and forth between
          source code and AST tree
        - Can extend the tree with (brought into scope) nodes, e.g. to add imports
        """

class DfsaPyCollector(DfsaCollector, ABC):

def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]:
Expand Down
86 changes: 43 additions & 43 deletions src/databricks/labs/ucx/source_code/linters/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from databricks.labs.ucx.source_code.files import LocalFile
from databricks.labs.ucx.source_code.graph import Dependency
from databricks.labs.ucx.source_code.known import KnownDependency
from databricks.labs.ucx.source_code.linters.base import PythonLinter
from databricks.labs.ucx.source_code.linters.base import PythonFixer, PythonLinter
from databricks.labs.ucx.source_code.linters.context import LinterContext
from databricks.labs.ucx.source_code.linters.imports import SysPathChange, UnresolvedPath
from databricks.labs.ucx.source_code.notebooks.cells import (
Expand All @@ -26,7 +26,6 @@
RunCell,
RunCommand,
)
from databricks.labs.ucx.source_code.notebooks.loaders import NotebookLoader
from databricks.labs.ucx.source_code.notebooks.magic import MagicLine
from databricks.labs.ucx.source_code.notebooks.sources import Notebook
from databricks.labs.ucx.source_code.path_lookup import PathLookup
Expand All @@ -42,7 +41,11 @@ class NotebookLinter:
"""

def __init__(
self, notebook: Notebook, path_lookup: PathLookup, context: LinterContext, parent_tree: Tree | None = None
self,
notebook: Notebook,
path_lookup: PathLookup,
context: LinterContext,
parent_tree: Tree | None = None,
):
self._context: LinterContext = context
self._path_lookup = path_lookup
Expand Down Expand Up @@ -76,6 +79,37 @@ def lint(self) -> Iterable[Advice]:
)
return

def apply(self) -> None:
    """Apply changes to the notebook.

    Runs the registered linters over each cell, applies any matching fixers,
    and finally flushes the migrated code back through the notebook.
    """
    # Parsing up front presumably populates self._python_tree_cache for the
    # Python cells; on failure we bail out and let `lint` report the details.
    maybe_tree = self._parse_notebook(self._notebook, parent_tree=self._parent_tree)
    if maybe_tree and maybe_tree.failure:
        logger.warning("Failed to parse the notebook, run linter for more details.")
        return
    for cell in self._notebook.cells:
        try:
            linter = self._context.linter(cell.language.language)
        except ValueError:  # Language is not supported (yet)
            continue
        fixed_code = cell.original_code  # For default fixing
        tree = self._python_tree_cache.get((self._notebook.path, cell))  # For Python fixing
        is_python_cell = isinstance(cell, PythonCell)
        if is_python_cell and tree:
            # Python cells with a cached tree are linted on the tree directly.
            advices = cast(PythonLinter, linter).lint_tree(tree)
        else:
            advices = linter.lint(cell.original_code)
        for advice in advices:
            fixer = self._context.fixer(cell.language.language, advice.code)
            if not fixer:
                continue
            if is_python_cell and tree:
                # By calling `apply_tree` instead of `apply`, we chain fixes on the same tree
                tree = cast(PythonFixer, fixer).apply_tree(tree)
            else:
                # Non-Python fixes chain on the incrementally fixed source text.
                fixed_code = fixer.apply(fixed_code)
        # Python cells are re-rendered from the (possibly fixed) tree; other
        # cells take the accumulated text fixes.
        cell.migrated_code = tree.node.as_string() if tree else fixed_code
    self._notebook.back_up_original_and_flush_migrated_code()
    return

def _parse_notebook(self, notebook: Notebook, *, parent_tree: Tree) -> MaybeTree | None:
"""Parse a notebook by parsing its cells.
Expand Down Expand Up @@ -264,50 +298,16 @@ def apply(self) -> None:
source_container = self._dependency.load(self._path_lookup)
if isinstance(source_container, LocalFile):
self._apply_file(source_container)
elif isinstance(source_container, Notebook):
self._apply_notebook(source_container)

def _apply_file(self, local_file: LocalFile) -> None:
    """Run all applicable fixes over the local file and persist the result."""
    local_file.migrated_code = self._context.apply_fixes(
        local_file.language,
        local_file.original_code,
    )
    local_file.back_up_original_and_flush_migrated_code()


class NotebookMigrator:
    """Legacy notebook migrator (superseded by `NotebookLinter.apply` in this change).

    Applies language-specific fixes cell by cell and writes the migrated
    notebook back to disk, keeping a `.bak` copy of the original.
    """

    def __init__(self, languages: LinterContext):
        # TODO: move languages to `apply`
        self._languages = languages

    def revert(self, path: Path) -> bool:
        """Restore a notebook from its `.bak` backup; return True when content was written."""
        backup_path = path.with_suffix(".bak")
        if not backup_path.exists():
            return False
        # write_text returns the character count, so an empty backup yields False.
        return path.write_text(backup_path.read_text()) > 0

    def apply(self, path: Path) -> bool:
        """Load the notebook at `path` and apply fixes; return True when it changed."""
        if not path.exists():
            return False
        dependency = Dependency(NotebookLoader(), path)
        # TODO: the interface for this method has to be changed
        lookup = PathLookup.from_sys_path(Path.cwd())
        container = dependency.load(lookup)
        assert isinstance(container, Notebook)
        return self._apply(container)

    def _apply(self, notebook: Notebook) -> bool:
        """Fix each supported cell; persist and back up the notebook if anything changed."""
        changed = False
        for cell in notebook.cells:
            # %run is not a supported language, so this needs to come first
            if isinstance(cell, RunCell):
                # TODO migration data, see https://github.com/databrickslabs/ucx/issues/1327
                continue
            if not self._languages.is_supported(cell.language.language):
                continue
            migrated_code = self._languages.apply_fixes(cell.language.language, cell.original_code)
            if migrated_code != cell.original_code:
                cell.migrated_code = migrated_code
                changed = True
        if changed:
            # TODO https://github.com/databrickslabs/ucx/issues/1327 store 'migrated' status
            # Original is moved aside as the `.bak`, then the migrated text replaces it.
            notebook.path.replace(notebook.path.with_suffix(".bak"))
            notebook.path.write_text(notebook.to_migrated_code())
        return changed
def _apply_notebook(self, notebook: Notebook) -> None:
    """Delegate fixing of a notebook to a `NotebookLinter` built from this context."""
    NotebookLinter(notebook, self._path_lookup, self._context, self._inherited_tree).apply()
45 changes: 22 additions & 23 deletions src/databricks/labs/ucx/source_code/linters/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import TypeVar

from astroid import Attribute, Call, Const, Name, NodeNG # type: ignore

from databricks.labs.ucx.github import IssueType, construct_new_issue_url
from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationIndex, TableMigrationStatus
from databricks.labs.ucx.source_code.base import (
Advice,
Expand All @@ -20,6 +22,7 @@
from databricks.labs.ucx.source_code.linters.base import (
SqlLinter,
Fixer,
PythonFixer,
PythonLinter,
DfsaPyCollector,
TablePyCollector,
Expand All @@ -33,7 +36,6 @@
from databricks.labs.ucx.source_code.linters.from_table import FromTableSqlLinter
from databricks.labs.ucx.source_code.python.python_ast import (
MatchingVisitor,
MaybeTree,
Tree,
TreeHelper,
)
Expand Down Expand Up @@ -170,8 +172,16 @@ def lint(

def apply(self, from_table: FromTableSqlLinter, index: TableMigrationIndex, node: Call) -> None:
table_arg = self._get_table_arg(node)
assert isinstance(table_arg, Const)
# TODO locate constant when value is inferred
if not isinstance(table_arg, Const):
# TODO: https://github.com/databrickslabs/ucx/issues/3695
source_code = node.as_string()
body = (
"# Desired behaviour\n\nAutofix following Python code\n\n"
f"``` python\nTODO: Add relevant source code\n{source_code}\n```"
)
url = construct_new_issue_url(IssueType.FEATURE, "Autofix the following Python code", body)
logger.warning(f"Cannot fix the following Python code: {source_code}. Please report this issue at {url}")
return
info = UsedTable.parse(table_arg.value, from_table.schema)
dst = self._find_dest(index, info)
if dst is not None:
Expand Down Expand Up @@ -393,7 +403,7 @@ def matchers(self) -> dict[str, _TableNameMatcher]:
return self._matchers


class SparkTableNamePyLinter(PythonLinter, Fixer, TablePyCollector):
class SparkTableNamePyLinter(PythonLinter, PythonFixer, TablePyCollector):
"""Linter for table name references in PySpark
Examples:
Expand Down Expand Up @@ -427,21 +437,15 @@ def lint_tree(self, tree: Tree) -> Iterable[Advice]:
assert isinstance(node, Call)
yield from matcher.lint(self._from_table, self._index, self._session_state, node)

def apply(self, code: str) -> str:
maybe_tree = MaybeTree.from_source_code(code)
if not maybe_tree.tree:
assert maybe_tree.failure is not None
logger.warning(maybe_tree.failure.message)
return code
tree = maybe_tree.tree
# we won't be doing it like this in production, but for the sake of the example
def apply_tree(self, tree: Tree) -> Tree:
"""Apply the fixes to the AST tree."""
for node in tree.walk():
matcher = self._find_matcher(node)
if matcher is None:
continue
assert isinstance(node, Call)
matcher.apply(self._from_table, self._index, node)
return tree.node.as_string()
return tree

def _find_matcher(self, node: NodeNG) -> _TableNameMatcher | None:
if not isinstance(node, Call):
Expand Down Expand Up @@ -476,7 +480,7 @@ def _visit_call_nodes(cls, tree: Tree) -> Iterable[tuple[Call, NodeNG]]:
yield call_node, query


class _SparkSqlPyLinter(_SparkSqlAnalyzer, PythonLinter, Fixer):
class _SparkSqlPyLinter(_SparkSqlAnalyzer, PythonLinter, PythonFixer):
"""Linter for SparkSQL used within PySpark."""

def __init__(self, sql_linter: SqlLinter, sql_fixer: Fixer | None):
Expand All @@ -503,23 +507,18 @@ def lint_tree(self, tree: Tree) -> Iterable[Advice]:
code = self.diagnostic_code
yield dataclasses.replace(advice.replace_from_node(call_node), code=code)

def apply(self, code: str) -> str:
def apply_tree(self, tree: Tree) -> Tree:
"""Apply the fixes to the AST tree."""
if not self._sql_fixer:
return code
maybe_tree = MaybeTree.from_source_code(code)
if maybe_tree.failure:
logger.warning(maybe_tree.failure.message)
return code
assert maybe_tree.tree is not None
tree = maybe_tree.tree
return tree
for _call_node, query in self._visit_call_nodes(tree):
if not isinstance(query, Const) or not isinstance(query.value, str):
continue
# TODO avoid applying same fix multiple times
# this requires changing 'apply' API in order to check advice fragment location
new_query = self._sql_fixer.apply(query.value)
query.value = new_query
return tree.node.as_string()
return tree


class FromTableSqlPyLinter(_SparkSqlPyLinter):
Expand Down
Loading

0 comments on commit eb6009a

Please sign in to comment.