Showing 16 changed files with 549 additions and 5 deletions.
**`pyproject.toml`:**

```toml
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[tool.poetry]
name = "spark_upgrade"
version = "0.0.1"
description = "Rules to migrate 'scaletest'"
# Add any other metadata you need

[tool.poetry.dependencies]
python = "^3.9"
polyglot_piranha = "*"

[tool.poetry.dev-dependencies]
pytest = "7.4.*"

# Poetry scripts are `command = "module:function"` entries; pytest needs no
# entry here, since it ships its own console script.
[tool.poetry.scripts]
scala_test = "spark_upgrade.main:main"
```
**`README.md`:**

````markdown
# `Spark Upgrade` Plugin (WIP)

Upgrades your codebase to Spark 3.3.

Currently, it updates to [v3.3](https://spark.apache.org/releases/spark-release-3-3-0.html) only.
Supported rewrites:
* `CalendarInterval` -> `DateTimeConstants`

## Usage

Clone the repository - `git clone https://github.com/uber/piranha.git`

Install the dependencies - `pip3 install -r plugins/spark_upgrade/requirements.txt`

Run the tool - `python3 plugins/spark_upgrade/main.py -h`

CLI:
```
usage: main.py [-h] --path_to_codebase PATH_TO_CODEBASE

Updates the codebase to use a new version of `Spark 3.3`.

options:
  -h, --help            show this help message and exit
  --path_to_codebase PATH_TO_CODEBASE
                        Path to the codebase directory.
```

## Test
```
pytest plugins/
```
````
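A typical run (with a placeholder path) might look like:

```
python3 plugins/spark_upgrade/main.py --path_to_codebase /path/to/scala/codebase
```

The `--new_version` flag defaults to `3.3` (see `main.py` below), so only the codebase path is required.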
The `ExecutePiranha` base class shared by the plugin's rewrite steps:

```python
# Copyright (c) 2023 Uber Technologies, Inc.

# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0

# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Union, List, Dict
from polyglot_piranha import PiranhaArguments, execute_piranha, Rule, RuleGraph, OutgoingEdges


class ExecutePiranha(ABC):
    '''
    This abstract class implements the higher-level strategy for
    applying a specific polyglot piranha configuration, i.e. rules/edges.
    '''

    def __init__(
        self,
        paths_to_codebase: List[str],
        language: str,
        substitutions: Dict[str, str],
        dry_run=False,
        allow_dirty_ast=False,
    ):
        self.paths_to_codebase = paths_to_codebase
        self.language = language
        self.substitutions = substitutions
        self.dry_run = dry_run
        self.allow_dirty_ast = allow_dirty_ast

    def __call__(self) -> dict:
        piranha_args = self.get_piranha_arguments()
        self.summaries = execute_piranha(piranha_args)

        output = self.summaries_to_custom_dict(self.summaries)
        success = True

        if not output:
            success = False
            output = {}
        output[self.step_name()] = success
        return output

    @abstractmethod
    def step_name(self) -> str:
        '''
        The overriding method should return the name of the strategy.
        '''
        ...

    @abstractmethod
    def get_rules(self) -> List[Rule]:
        '''
        The list of rules.
        '''
        ...

    def get_edges(self) -> List[OutgoingEdges]:
        '''
        The list of edges.
        '''
        return []

    def get_rule_graph(self) -> RuleGraph:
        '''
        Strategy to construct a rule graph from rules/edges.
        '''
        return RuleGraph(rules=self.get_rules(), edges=self.get_edges())

    def path_to_configuration(self) -> Union[None, str]:
        '''
        Path to the rules/edges TOML file (in case a rule graph is not specified).
        '''
        return None

    def get_piranha_arguments(self) -> PiranhaArguments:
        rg = self.get_rule_graph()
        path_to_conf = self.path_to_configuration()
        if rg.rules and path_to_conf:
            raise Exception(
                "You have provided both a rule graph and a path to configurations. Do not provide both.")
        if not rg.rules and not path_to_conf:
            raise Exception("You have provided neither a rule graph nor a path to configurations.")
        if rg.rules:
            return PiranhaArguments(
                language=self.language,
                paths_to_codebase=self.paths_to_codebase,
                substitutions=self.substitutions,
                rule_graph=self.get_rule_graph(),
                cleanup_comments=True,
                dry_run=self.dry_run,
                allow_dirty_ast=self.allow_dirty_ast,
            )
        return PiranhaArguments(
            language=self.language,
            paths_to_codebase=self.paths_to_codebase,
            substitutions=self.substitutions,
            path_to_configurations=self.path_to_configuration(),
            cleanup_comments=True,
            dry_run=self.dry_run,
            allow_dirty_ast=self.allow_dirty_ast,
        )

    def get_matches(self, specified_rule: str) -> List[dict]:
        """
        This function gets the matches for a specified rule.
        """
        return [match.matches
                for summary in self.summaries
                for actual_rule, match in summary.matches if specified_rule == actual_rule]

    @abstractmethod
    def summaries_to_custom_dict(self, _) -> Dict[str, Any]:
        '''
        The overriding method should implement the logic for extracting the
        useful information from the matches/rewrites reported by polyglot piranha into a dict.
        '''
        ...
```
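The commit's `update_calendar_interval` module (imported by `main.py` below, but not reproduced in this view) plugs into this base class. The following is a hypothetical sketch of such a subclass; the class name, the `query`/`replace` strings, and the summary shape are illustrative, not the commit's actual rules:

```python
from typing import Any, Dict, List

from polyglot_piranha import Rule

from execute_piranha import ExecutePiranha  # assumed module name for the base class above


class UpdateCalendarIntervalSketch(ExecutePiranha):
    """Hypothetical example: swap `CalendarInterval` constants for `DateTimeConstants`."""

    def __init__(self, paths_to_codebase: List[str]):
        super().__init__(
            paths_to_codebase=paths_to_codebase,
            language="scala",
            substitutions={},  # no holes to substitute in this sketch
        )

    def step_name(self) -> str:
        return "update CalendarInterval"

    def get_rules(self) -> List[Rule]:
        # One structural rewrite rule, expressed with Piranha's concrete-syntax
        # (`cs`) queries; the real plugin likely chains several such rules.
        update_import = Rule(
            name="update_import",
            query="cs import org.apache.spark.unsafe.types.CalendarInterval",
            replace_node="*",
            replace="import org.apache.spark.sql.catalyst.util.DateTimeConstants",
        )
        return [update_import]

    def summaries_to_custom_dict(self, summaries) -> Dict[str, Any]:
        # Map each rewritten file to its new content.
        return {summary.path: summary.content for summary in summaries}
```

Calling an instance, as `main.py` does (`UpdateCalendarIntervalSketch([path])()`), runs Piranha and returns this dict plus a boolean recorded under `step_name()`.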
**`plugins/spark_upgrade/main.py`**, the CLI entry point:

```python
# Copyright (c) 2023 Uber Technologies, Inc.

# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0

# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from update_calendar_interval import UpdateCalendarInterval


def _parse_args():
    parser = argparse.ArgumentParser(
        description="Updates the codebase to use a new version of `spark3`"
    )
    parser.add_argument(
        "--path_to_codebase",
        required=True,
        help="Path to the codebase directory.",
    )
    parser.add_argument(
        "--new_version",
        default="3.3",
        help="Version of `Spark` to update to.",
    )
    args = parser.parse_args()
    return args


def main():
    args = _parse_args()
    if args.new_version == "3.3":
        upgrade_to_spark_3_3(args.path_to_codebase)


def upgrade_to_spark_3_3(path_to_codebase):
    update_calendar_interval = UpdateCalendarInterval([path_to_codebase])
    summary = update_calendar_interval()


if __name__ == "__main__":
    main()
```
**`plugins/spark_upgrade/requirements.txt`:**

```
polyglot-piranha
pytest
```
**`plugins/spark_upgrade/tests/resources/expected/sample.scala`** (15 additions):

```scala
package org.piranha

import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.unsafe.types.CalendarInterval

object CalendarIntervalExample {
  def main(args: Array[String]): Unit = {
    // Accessing MICROS_PER_SECOND constant
    val microsPerSecond = DateTimeConstants.MICROS_PER_SECOND
    val microsPerHour = DateTimeConstants.MICROS_PER_HOUR
    val fromYearMonthString = DateTimeConstants.fromYearMonthString("1-2")
    val fromDayTimeString = DateTimeConstants.fromDayTimeString("1-2")
    println(s"Microseconds per Second: $microsPerSecond")
  }
}
```
The corresponding input fixture (`plugins/spark_upgrade/tests/resources/input/sample.scala`), before the rewrite:

```scala
package org.piranha

import org.apache.spark.unsafe.types.CalendarInterval

object CalendarIntervalExample {
  def main(args: Array[String]): Unit = {
    // Accessing MICROS_PER_SECOND constant
    val microsPerSecond = CalendarInterval.MICROS_PER_SECOND
    val microsPerHour = CalendarInterval.MICROS_PER_HOUR
    val fromYearMonthString = CalendarInterval.fromYearMonthString("1-2")
    val fromDayTimeString = CalendarInterval.fromDayTimeString("1-2")
    println(s"Microseconds per Second: $microsPerSecond")
  }
}
```
The plugin's test suite, which exercises the end-to-end rewrite against the fixtures above:

```python
# Copyright (c) 2023 Uber Technologies, Inc.

# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0

# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path
from os import walk
from tempfile import TemporaryDirectory

from update_calendar_interval import UpdateCalendarInterval

FORMAT = "%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s"
logging.basicConfig(format=FORMAT)
logging.getLogger().setLevel(logging.DEBUG)


def test_update_CalendarInterval():
    input_codebase = "plugins/spark_upgrade/tests/resources/input/"
    expected_codebase = "plugins/spark_upgrade/tests/resources/expected/"
    with TemporaryDirectory() as temp_dir:
        copy_dir(input_codebase, temp_dir)
        update_calendar_interval = UpdateCalendarInterval([temp_dir])
        summary = update_calendar_interval()
        assert summary is not None
        assert is_as_expected_files(expected_codebase, temp_dir)


def remove_whitespace(input_str):
    """Removes all the whitespace from the string.

    Args:
        input_str (str): input string
    Returns:
        str: transformed input string with no whitespace
    """
    return "".join(input_str.split()).strip()


def copy_dir(source_dir, dest_dir):
    """Copies the files in {source_dir} to {dest_dir}.

    Properties to note:
    * Assumes {dest_dir} is present.
    * Overwrites similarly named files.

    Args:
        source_dir (str): directory to copy from
        dest_dir (str): directory to copy into
    """
    for root, _, files in walk(source_dir):
        src_root = Path(root)
        dest_root = Path(dest_dir, src_root.relative_to(source_dir))
        Path(dest_root).mkdir(parents=True, exist_ok=True)
        for f in files:
            src_file = Path(src_root, f)
            dest_file = Path(
                dest_root, f.replace(".testjava", ".java").replace(".testkt", ".kt")
            )
            dest_file.write_text(src_file.read_text())


def is_as_expected_files(path_to_expected, path_to_actual):
    for root, _, files in walk(path_to_actual):
        actual_root = Path(root)
        expected_root = Path(path_to_expected, actual_root.relative_to(path_to_actual))
        for file_name in files:
            actual_file = Path(actual_root, file_name)
            expected_file = Path(expected_root, file_name)
            if not expected_file.exists():
                expected_file = Path(
                    expected_root,
                    file_name.replace(".java", ".testjava")
                    .replace(".kt", ".testkt")
                    .replace(".swift", ".testswift")
                    .replace(".go", ".testgo"),
                )
            _actual_content = actual_file.read_text()
            actual_content = remove_whitespace(_actual_content).strip()
            expected_content = remove_whitespace(expected_file.read_text())
            if not actual_content and expected_file.exists():
                logging.error("Actual content of the file was empty!")
                return False
            if expected_content != actual_content:
                logging.error(f"Actual content of the file:\n{_actual_content}")
                return False
    return True
```
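As a quick illustration of the whitespace-insensitive comparison that `is_as_expected_files` relies on (the string values here are made up):

```python
# remove_whitespace drops all whitespace, so formatting-only differences
# between the rewritten file and the expected fixture compare equal.
assert remove_whitespace("val x =\n  1") == "valx=1"
assert remove_whitespace("val x = 1") == remove_whitespace("valx = 1")
```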