
Commit

Merge branch 'master' into master
kevtran2 committed Sep 6, 2023
2 parents 15de5b1 + 33edaaf commit bdff029
Showing 16 changed files with 549 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
Version 0.3.19
-------------
* Fix Swift cleanup of statements after return

Version 0.3.18
-------------
* Improve Go feature flag cleanup
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -7,7 +7,7 @@ authors = [
]
name = "piranha"
description = "Polyglot Piranha is a library for performing structural find and replace with deep cleanup."
version = "0.3.18"
version = "0.3.19"
edition = "2021"
include = ["pyproject.toml", "src/"]
exclude = ["legacy"]
22 changes: 22 additions & 0 deletions plugins/pyproject.toml
@@ -0,0 +1,22 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[tool.poetry]
name = "spark_upgrade"
version = "0.0.1"
description = "Rules to migrate 'scaletest'"
# Add any other metadata you need

[tool.poetry.dependencies]
python = "^3.9"
polyglot_piranha = "*"

[tool.poetry.dev-dependencies]
pytest = "7.4.x"

[tool.poetry.scripts."scala_test"]
main = "spark_upgrade.main:main"

[tool.poetry.scripts."pytest"]
main = "pytest"
35 changes: 35 additions & 0 deletions plugins/spark_upgrade/README.md
@@ -0,0 +1,35 @@
# `Spark Upgrade` Plugin (WIP)

Upgrades your codebase to [Spark 3.3](https://spark.apache.org/releases/spark-release-3-3-0.html); this is currently the only supported target version.

Supported rewrites:
* `CalendarInterval` -> `DateTimeConstants` (see the before/after samples under `tests/resources/`)

## Usage

Clone the repository - `git clone https://github.com/uber/piranha.git`

Install the dependencies - `pip3 install -r plugins/spark_upgrade/requirements.txt`

Run the tool - `python3 plugins/spark_upgrade/main.py -h`

CLI:
```
usage: main.py [-h] --path_to_codebase PATH_TO_CODEBASE [--new_version NEW_VERSION]

Updates the codebase to use a new version of `spark3`

options:
  -h, --help            show this help message and exit
  --path_to_codebase PATH_TO_CODEBASE
                        Path to the codebase directory.
  --new_version NEW_VERSION
                        Version of `Spark` to update to.
```
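
The plugin can also be driven from Python, mirroring what `main.py` does. A minimal sketch (the codebase path below is illustrative):

```
from update_calendar_interval import UpdateCalendarInterval

# Hypothetical path; point this at the root of your Scala/Spark codebase.
update_calendar_interval = UpdateCalendarInterval(["path/to/codebase"])
summary = update_calendar_interval()
# `summary` is a dict keyed by the step name (plus any custom fields),
# with a boolean indicating whether the step succeeded.
print(summary)
```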

## Test
```
pytest plugins/
```
117 changes: 117 additions & 0 deletions plugins/spark_upgrade/execute_piranha.py
@@ -0,0 +1,117 @@
# Copyright (c) 2023 Uber Technologies, Inc.

# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0

# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Union, List, Dict
from polyglot_piranha import PiranhaArguments, execute_piranha, Rule, RuleGraph, OutgoingEdges


class ExecutePiranha(ABC):
    '''
    This abstract class implements the higher-level strategy for applying a
    specific Polyglot Piranha configuration, i.e., rules/edges.
    '''

    def __init__(self, paths_to_codebase: List[str], language: str, substitutions: Dict[str, str], dry_run=False, allow_dirty_ast=False):
        self.paths_to_codebase = paths_to_codebase
        self.language = language
        self.substitutions = substitutions
        self.dry_run = dry_run
        self.allow_dirty_ast = allow_dirty_ast

    def __call__(self) -> dict:
        piranha_args = self.get_piranha_arguments()
        self.summaries = execute_piranha(piranha_args)

        output = self.summaries_to_custom_dict(self.summaries)
        success = True

        if not output:
            success = False
            output = {}
        output[self.step_name()] = success
        return output

    @abstractmethod
    def step_name(self) -> str:
        '''
        The overriding method should return the name of the strategy.
        '''
        ...

    @abstractmethod
    def get_rules(self) -> List[Rule]:
        '''
        The list of rules.
        '''
        ...

    def get_edges(self) -> List[OutgoingEdges]:
        '''
        The list of edges.
        '''
        return []

    def get_rule_graph(self) -> RuleGraph:
        '''
        Strategy to construct a rule graph from rules/edges.
        '''
        return RuleGraph(rules=self.get_rules(), edges=self.get_edges())

    def path_to_configuration(self) -> Union[None, str]:
        '''
        Path to the rules/edges TOML file (in case a rule graph is not specified).
        '''
        return None

    def get_piranha_arguments(self) -> PiranhaArguments:
        rg = self.get_rule_graph()
        path_to_conf = self.path_to_configuration()
        if rg.rules and path_to_conf:
            raise Exception(
                "You have provided both a rule graph and a path to configurations; provide only one.")
        if not rg.rules and not path_to_conf:
            raise Exception(
                "You must provide either a rule graph or a path to configurations.")
        if rg.rules:
            return PiranhaArguments(
                language=self.language,
                paths_to_codebase=self.paths_to_codebase,
                substitutions=self.substitutions,
                rule_graph=rg,
                cleanup_comments=True,
                dry_run=self.dry_run,
                allow_dirty_ast=self.allow_dirty_ast
            )
        return PiranhaArguments(
            language=self.language,
            paths_to_codebase=self.paths_to_codebase,
            substitutions=self.substitutions,
            path_to_configurations=path_to_conf,
            cleanup_comments=True,
            dry_run=self.dry_run,
            allow_dirty_ast=self.allow_dirty_ast
        )

    def get_matches(self, specified_rule: str) -> List[dict]:
        """
        This function gets the matches for a specified rule.
        """
        return [match.matches
                for summary in self.summaries
                for actual_rule, match in summary.matches if specified_rule == actual_rule]

    @abstractmethod
    def summaries_to_custom_dict(self, _) -> Dict[str, Any]:
        '''
        The overriding method should implement the logic for extracting the
        useful information from the matches/rewrites reported by Polyglot
        Piranha into a dict.
        '''
        ...
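
To make the `ExecutePiranha` contract concrete, here is a minimal, hypothetical subclass. It is not the actual `UpdateCalendarInterval` implementation; the rule's name, query, and replacement are illustrative only:

```
from typing import Any, Dict, List

from execute_piranha import ExecutePiranha
from polyglot_piranha import Rule


class RenameCalendarInterval(ExecutePiranha):
    """Hypothetical step: rewrite `CalendarInterval.X` to `DateTimeConstants.X`."""

    def __init__(self, paths_to_codebase: List[str]):
        super().__init__(
            paths_to_codebase=paths_to_codebase,
            language="scala",
            substitutions={},
        )

    def step_name(self) -> str:
        return "rename_calendar_interval"

    def get_rules(self) -> List[Rule]:
        # Concrete-syntax rule: `:[x]` captures the accessed member and
        # `replace_node="*"` rewrites the entire matched node.
        return [
            Rule(
                name="rename_calendar_interval",
                query="cs CalendarInterval.:[x]",
                replace_node="*",
                replace="DateTimeConstants.:[x]",
            )
        ]

    def summaries_to_custom_dict(self, _) -> Dict[str, Any]:
        # Surface the number of rewrites performed across all files.
        return {"rewrite_count": sum(len(s.rewrites) for s in self.summaries)}
```

Calling an instance of such a subclass (see `__call__` above) returns this custom dict augmented with a success flag keyed by `step_name()`; `get_matches("rename_calendar_interval")` would then return the captured groups for that rule.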
47 changes: 47 additions & 0 deletions plugins/spark_upgrade/main.py
@@ -0,0 +1,47 @@
# Copyright (c) 2023 Uber Technologies, Inc.

# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0

# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging

from update_calendar_interval import UpdateCalendarInterval


def _parse_args():
    parser = argparse.ArgumentParser(
        description="Updates the codebase to use a new version of `spark3`"
    )
    parser.add_argument(
        "--path_to_codebase",
        required=True,
        help="Path to the codebase directory.",
    )
    parser.add_argument(
        "--new_version",
        default="3.3",
        help="Version of `Spark` to update to.",
    )
    args = parser.parse_args()
    return args


def main():
    args = _parse_args()
    if args.new_version == "3.3":
        upgrade_to_spark_3_3(args.path_to_codebase)
    else:
        # Fail loudly instead of silently doing nothing for unsupported versions.
        raise ValueError(f"Unsupported Spark version: {args.new_version}")


def upgrade_to_spark_3_3(path_to_codebase):
    update_calendar_interval = UpdateCalendarInterval([path_to_codebase])
    summary = update_calendar_interval()
    logging.info("Upgrade summary: %s", summary)


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions plugins/spark_upgrade/requirements.txt
@@ -0,0 +1,2 @@
polyglot-piranha
pytest
Empty file.
15 changes: 15 additions & 0 deletions plugins/spark_upgrade/tests/resources/expected/sample.scala
@@ -0,0 +1,15 @@
package org.piranha

import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.unsafe.types.CalendarInterval

object CalendarIntervalExample {
  def main(args: Array[String]): Unit = {
    // Accessing MICROS_PER_SECOND constant
    val microsPerSecond = DateTimeConstants.MICROS_PER_SECOND
    val microsPerHour = DateTimeConstants.MICROS_PER_HOUR
    val fromYearMonthString = DateTimeConstants.fromYearMonthString("1-2")
    val fromDayTimeString = DateTimeConstants.fromDayTimeString("1-2")
    println(s"Microseconds per Second: $microsPerSecond")
  }
}
14 changes: 14 additions & 0 deletions plugins/spark_upgrade/tests/resources/input/sample.scala
@@ -0,0 +1,14 @@
package org.piranha

import org.apache.spark.unsafe.types.CalendarInterval

object CalendarIntervalExample {
  def main(args: Array[String]): Unit = {
    // Accessing MICROS_PER_SECOND constant
    val microsPerSecond = CalendarInterval.MICROS_PER_SECOND
    val microsPerHour = CalendarInterval.MICROS_PER_HOUR
    val fromYearMonthString = CalendarInterval.fromYearMonthString("1-2")
    val fromDayTimeString = CalendarInterval.fromDayTimeString("1-2")
    println(s"Microseconds per Second: $microsPerSecond")
  }
}
95 changes: 95 additions & 0 deletions plugins/spark_upgrade/tests/test_spark_upgrade.py
@@ -0,0 +1,95 @@
# Copyright (c) 2023 Uber Technologies, Inc.

# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0

# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path
from os import walk
from tempfile import TemporaryDirectory

from update_calendar_interval import UpdateCalendarInterval

FORMAT = "%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s"
logging.basicConfig(format=FORMAT)
logging.getLogger().setLevel(logging.DEBUG)


def test_update_CalendarInterval():
    input_codebase = "plugins/spark_upgrade/tests/resources/input/"
    expected_codebase = "plugins/spark_upgrade/tests/resources/expected/"
    with TemporaryDirectory() as temp_dir:
        copy_dir(input_codebase, temp_dir)
        update_calendar_interval = UpdateCalendarInterval([temp_dir])
        summary = update_calendar_interval()
        assert summary is not None
        assert is_as_expected_files(expected_codebase, temp_dir)


def remove_whitespace(input_str):
    """Removes all whitespace from the string.
    Args:
        input_str (str): input string
    Returns:
        str: transformed input string with no whitespace
    """
    return "".join(input_str.split())


def copy_dir(source_dir, dest_dir):
    """Copies the files in {source_dir} to {dest_dir}.
    Properties to note:
    * Assumes {dest_dir} is present.
    * Overwrites similarly named files.
    Args:
        source_dir (str): directory to copy from
        dest_dir (str): directory to copy into
    """
    for root, _, files in walk(source_dir):
        src_root = Path(root)
        dest_root = Path(dest_dir, src_root.relative_to(source_dir))
        Path(dest_root).mkdir(parents=True, exist_ok=True)
        for f in files:
            src_file = Path(src_root, f)
            dest_file = Path(
                dest_root, f.replace(".testjava", ".java").replace(".testkt", ".kt")
            )
            dest_file.write_text(src_file.read_text())


def is_as_expected_files(path_to_expected, path_to_actual):
    for root, _, files in walk(path_to_actual):
        actual_root = Path(root)
        expected_root = Path(path_to_expected, actual_root.relative_to(path_to_actual))
        for file_name in files:
            actual_file = Path(actual_root, file_name)
            expected_file = Path(expected_root, file_name)
            if not expected_file.exists():
                expected_file = Path(
                    expected_root,
                    file_name.replace(".java", ".testjava")
                    .replace(".kt", ".testkt")
                    .replace(".swift", ".testswift")
                    .replace(".go", ".testgo"),
                )
            _actual_content = actual_file.read_text()
            actual_content = remove_whitespace(_actual_content)
            expected_content = remove_whitespace(expected_file.read_text())
            if not actual_content and expected_file.exists():
                logging.error("Actual content of the file was empty!")
                return False
            if expected_content != actual_content:
                logging.error(f"Actual content of the file:\n{_actual_content}")
                return False
    return True