Skip to content

Commit

Permalink
several fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dannymeijer committed Sep 6, 2024
1 parent d368b3b commit 2746fdf
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ all-tests:
spark-tests:
@echo "\033[1mRunning Spark tests:\033[0m\n\033[35m This will run the Spark test suite against all specified environments\033[0m"
@echo "\033[1;31mWARNING:\033[0;33m This may take upward of 20-30 minutes to complete!\033[0m"
@hatch test -m spark --no-header --no-summary
@hatch test -m spark --no-header
.PHONY: non-spark-tests ## testing - Run non-spark tests in ALL environments
non-spark-tests:
@echo "\033[1mRunning non-Spark tests:\033[0m\n\033[35m This will run the non-Spark test suite against all specified environments\033[0m"
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ async_http = [
box = ["boxsdk[jwt]==3.8.1"]
pandas = ["pandas>=1.3", "setuptools", "numpy<2.0.0"]
pyspark = ["pyspark>=3.2.0", "pyarrow>13"]
spark_connect = ["pyspark[connect]>=3.5,<4.0"]
se = ["spark-expectations>=2.1.0"]
# SFTP dependencies in to_csv line_iterator
sftp = ["paramiko>=2.6.0"]
Expand Down Expand Up @@ -269,7 +270,7 @@ version = ["pyspark33", "pyspark34", "pyspark35"]

[[tool.hatch.envs.hatch-test.matrix]]
python = ["3.11", "3.12"]
version = ["pyspark35"]
version = ["pyspark35", "pyspark35connect"]

[tool.hatch.envs.hatch-test.overrides]
matrix.version.extra-dependencies = [
Expand All @@ -291,6 +292,9 @@ matrix.version.extra-dependencies = [
{ value = "pyspark>=3.5,<3.6", if = [
"pyspark35",
] },
{ value = "pyspark[connect]>=3.5,<3.6", if = [
"pyspark35connect",
] },
]

name.".*".env-vars = [
Expand All @@ -301,7 +305,7 @@ name.".*".env-vars = [
]

[tool.pytest.ini_options]
addopts = "-q --color=yes --order-scope=module"
addopts = "--color=yes --order-scope=module"
log_level = "CRITICAL"
testpaths = ["tests"]
markers = [
Expand Down
19 changes: 14 additions & 5 deletions src/koheesio/spark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import Field

from pyspark.sql import Column
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import DataFrame as _SparkDataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

Expand All @@ -21,12 +21,21 @@

from koheesio import Step, StepOutput
from koheesio.spark.utils import get_spark_minor_version
from koheesio.logger import warn

if get_spark_minor_version() >= 3.5:
from pyspark.sql.connect.session import DataFrame as SparkConnectDataFrame

DataFrame = Union[SparkDataFrame, SparkConnectDataFrame]
# Default to the classic (non-Connect) DataFrame; widened to a Union below
# when the Spark Connect implementation is importable.
DataFrame = _SparkDataFrame
# Spark Connect support first shipped usably with pyspark 3.5.
if get_spark_minor_version() >= 3.5:
    # Guarded import: pyspark may be installed without the "connect" extra,
    # in which case the connect submodules raise ImportError.
    # NOTE(review): the Connect DataFrame is usually defined in
    # pyspark.sql.connect.dataframe — confirm session re-exports it.
    try:
        from pyspark.sql.connect.session import DataFrame as _SparkConnectDataFrame

        # Type alias accepting either DataFrame implementation.
        DataFrame = Union[_SparkDataFrame, _SparkConnectDataFrame]
    except ImportError:
        # Best-effort: fall back to the classic DataFrame and tell the user
        # how to enable Connect support instead of failing the import.
        warn(
            "Spark Connect is not available for use. If needed, please install the required package "
            "'koheesio[spark-connect]'."
        )

__all__ = ["SparkStep", "DataFrame", "current_timestamp_utc"]

class SparkStep(Step, ABC):
"""Base class for a Spark step
Expand All @@ -47,7 +56,7 @@ def spark(self) -> Optional[SparkSession]:
return SparkSession.getActiveSession()


# TODO: Move to spark/utils.py after reorganizing the code
def current_timestamp_utc(spark: SparkSession) -> Column:
    """Return a Column holding the current timestamp converted to UTC.

    The session's configured time zone (``spark.sql.session.timeZone``) is
    read from the given SparkSession and used as the source zone for the
    conversion.
    """
    session_timezone = spark.conf.get("spark.sql.session.timeZone")
    return F.to_utc_timestamp(F.current_timestamp(), session_timezone)

0 comments on commit 2746fdf

Please sign in to comment.