diff --git a/pyproject.toml b/pyproject.toml index 3edd7ec..493eb78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,17 +70,7 @@ tableau = ["tableauhyperapi>=0.0.19484", "tableauserverclient>=0.25"] # Snowflake dependencies snowflake = ["snowflake-connector-python>=3.12.0"] # Development dependencies -dev = [ - "black", - "isort", - "ruff", - "mypy", - "pylint", - "colorama", - "types-PyYAML", - "types-requests", - -] +dev = ["ruff", "mypy", "pylint", "colorama", "types-PyYAML", "types-requests"] test = [ "chispa", "coverage[toml]", @@ -153,23 +143,19 @@ Run `hatch run` to run scripts in the default environment. # Code Quality To check and format the codebase, we use: - - `black` for code formatting - - `isort` for import sorting (includes colorama for colored output) - - `ruff` for linting. + - `ruff` for linting, formatting, and sorting imports. - `mypy` for static type checking. - `pylint` for code quality checks. --- There are several ways to run style checks and formatting: - `hatch run black-check` will check the codebase with black without applying fixes. - `hatch run black-fmt` will format the codebase using black. -- `hatch run isort-check` will check the codebase with isort without applying fixes. -- `hatch run isort-fmt` will format the codebase using isort. - `hatch run ruff-check` will check the codebase with ruff without applying fixes. - `hatch run ruff-fmt` will format the codebase using ruff. - `hatch run mypy-check` will check the codebase with mypy. - `hatch run pylint-check` will check the codebase with pylint. - `hatch run check` will run all the above checks (including pylint and mypy). -- `hatch run fmt` or `hatch run fix` will format the codebase using black, isort, and ruff. +- `hatch run fmt` or `hatch run fix` will format the codebase using ruff. - `hatch run lint` will run ruff, mypy, and pylint. # Testing and Coverage @@ -207,22 +193,14 @@ features = [ # TODO: add scripts section based on Makefile # TODO: add bandit # Code Quality commands -black-check = "black --check --diff ." -black-fmt = "black ." -isort-check = "isort . --check --diff --color" -isort-fmt = "isort ." -ruff-check = "ruff check ." -ruff-fmt = "ruff check . --fix" +ruff-fmt = "ruff format --check --diff ." +ruff-fmt-fix = "ruff format ." +ruff-check = "ruff check . --diff" +ruff-check-fix = "ruff check . --fix" mypy-check = "mypy src" pylint-check = "pylint --output-format=colorized -d W0511 src" -check = [ - "- black-check", - "- isort-check", - "- ruff-check", - "- mypy-check", - "- pylint-check", -] -fmt = ["black-fmt", "isort-fmt", "ruff-fmt"] +check = ["- ruff-fmt", "- ruff-check", "- mypy-check", "- pylint-check"] +fmt = ["ruff-fmt-fix", "ruff-check-fix"] fix = "fmt" lint = ["- ruff-fmt", "- mypy-check", "pylint-check"] log-versions = "python --version && {env:HATCH_UV} pip freeze | grep pyspark" @@ -353,6 +331,7 @@ filterwarnings = [ "ignore:'PYARROW_IGNORE_TIMEZONE'.*:UserWarning:pyspark.pandas.*", # pydantic warnings "ignore:A custom validator is returning a value other than `self`.*.*:UserWarning:pydantic.main.*", + "ignore: 79 characters)' -> let Black handle this instead @@ -549,7 +494,6 @@ ignore = [ ] # Unlike Flake8, default to a complexity level of 10. mccabe.max-complexity = 10 - # Allow autofix for all enabled rules (when `--fix` is provided). fixable = [ "A", @@ -602,6 +546,22 @@ unfixable = [] # Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +[tool.ruff.lint.isort] +force-to-top = ["__future__", "typing"] +section-order = [ + "future", + "standard-library", + "third-party", + "pydantic", + "pyspark", + "first-party", + "local-folder", +] +sections.pydantic = ["pydantic"] +sections.pyspark = ["pyspark"] +detect-same-package = true +force-sort-within-sections = true + [tool.mypy] python_version = "3.10" files = ["koheesio/**/*.py"] diff --git a/src/koheesio/asyncio/__init__.py b/src/koheesio/asyncio/__init__.py index 093c4a0..3dc63fe 100644 --- a/src/koheesio/asyncio/__init__.py +++ b/src/koheesio/asyncio/__init__.py @@ -2,10 +2,12 @@ This module provides classes for asynchronous steps in the koheesio package. """ -from typing import Dict, Union +from typing import Dict, Optional, Union from abc import ABC from asyncio import iscoroutine +from pydantic import PrivateAttr + from koheesio.steps import Step, StepMetaClass, StepOutput @@ -65,7 +67,9 @@ def merge(self, other: Union[Dict, StepOutput]) -> "AsyncStepOutput": -------- ```python step_output = StepOutput(foo="bar") - step_output.merge({"lorem": "ipsum"}) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} + step_output.merge( + {"lorem": "ipsum"} + ) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Functionally similar to adding two dicts together; like running `{**dict_a, **dict_b}`. @@ -103,4 +107,4 @@ class Output(AsyncStepOutput): This class represents the output of the asyncio step. It inherits from the AsyncStepOutput class. """ - __output__: Output + _output: Optional[Output] = PrivateAttr(default=None) diff --git a/src/koheesio/asyncio/http.py b/src/koheesio/asyncio/http.py index ece14f1..14e4ad5 100644 --- a/src/koheesio/asyncio/http.py +++ b/src/koheesio/asyncio/http.py @@ -4,14 +4,14 @@ from __future__ import annotations +from typing import Any, Dict, List, Optional, Tuple, Union import asyncio import warnings -from typing import Any, Dict, List, Optional, Tuple, Union -import nest_asyncio # type: ignore[import-untyped] -import yarl from aiohttp import BaseConnector, ClientSession, TCPConnector from aiohttp_retry import ExponentialRetry, RetryClient, RetryOptionsBase +import nest_asyncio # type: ignore[import-untyped] +import yarl from pydantic import Field, SecretStr, field_validator, model_validator @@ -54,26 +54,28 @@ class AsyncHttpStep(AsyncStep, ExtraParamsMixin): from yarl import URL from typing import Dict, Any, Union, List, Tuple + # Initialize the AsyncHttpStep async def main(): session = ClientSession() - urls = [URL('https://example.com/api/1'), URL('https://example.com/api/2')] + urls = [URL("https://example.com/api/1"), URL("https://example.com/api/2")] retry_options = ExponentialRetry() connector = TCPConnector(limit=10) - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} step = AsyncHttpStep( client_session=session, url=urls, retry_options=retry_options, connector=connector, - headers=headers + headers=headers, ) # Execute the step - responses_urls= await step.get() + responses_urls = await step.get() return responses_urls + # Run the main function responses_urls = asyncio.run(main()) ``` diff --git a/src/koheesio/context.py b/src/koheesio/context.py index e0b818a..6fe725c 100644 --- a/src/koheesio/context.py +++ b/src/koheesio/context.py @@ -13,10 +13,10 @@ from __future__ import annotations -import re +from typing import Any, Dict, Iterator, Union from collections.abc import Mapping from pathlib import Path 
-from typing import Any, Dict, Iterator, Union +import re import jsonpickle # type: ignore[import-untyped] import tomli diff --git a/src/koheesio/integrations/box.py b/src/koheesio/integrations/box.py index cd5baab..114fdc0 100644 --- a/src/koheesio/integrations/box.py +++ b/src/koheesio/integrations/box.py @@ -10,17 +10,16 @@ * Application is authorized for the enterprise (Developer Portal - MyApp - Authorization) """ -import datetime -import re from typing import Any, Dict, Optional, Union from abc import ABC from io import BytesIO, StringIO from pathlib import PurePath +import re -import pandas as pd from boxsdk import Client, JWTAuth from boxsdk.object.file import File from boxsdk.object.folder import Folder +import pandas as pd from pyspark.sql.functions import expr, lit from pyspark.sql.types import StructType @@ -475,7 +474,7 @@ def execute(self) -> BoxReaderBase.Output: if len(files) > 0: self.log.info( - f"A total of {len(files)} files, that match the mask '{self.mask}' has been detected in {self.path}." + f"A total of {len(files)} files, that match the mask '{self.filter}' has been detected in {self.path}." f" They will be loaded into Spark Dataframe: {files}" ) else: diff --git a/src/koheesio/integrations/snowflake/test_utils.py b/src/koheesio/integrations/snowflake/test_utils.py index 8b85e97..8ae9ac3 100644 --- a/src/koheesio/integrations/snowflake/test_utils.py +++ b/src/koheesio/integrations/snowflake/test_utils.py @@ -25,7 +25,9 @@ def test_execute(self, mock_query): mock_query.expected_data = [("row1",), ("row2",)] # Act - instance = SnowflakeRunQueryPython(**COMMON_OPTIONS, query=query, account="42") + instance = SnowflakeRunQueryPython( + **COMMON_OPTIONS, query=query, account="42" + ) instance.execute() # Assert diff --git a/src/koheesio/integrations/spark/sftp.py b/src/koheesio/integrations/spark/sftp.py index 90812b8..d983913 100644 --- a/src/koheesio/integrations/spark/sftp.py +++ b/src/koheesio/integrations/spark/sftp.py @@ -12,15 +12,17 @@ For more details on each mode, see the docstring of the SFTPWriteMode enum. """ -import hashlib -import time from typing import Optional, Union from enum import Enum +import hashlib from pathlib import Path +import time from paramiko.sftp_client import SFTPClient from paramiko.transport import Transport +from pydantic import PrivateAttr + from koheesio.models import ( Field, InstanceOf, @@ -152,8 +154,8 @@ class SFTPWriter(Writer): ) # private attrs - __client__: SFTPClient - __transport__: Transport + _client: Optional[SFTPClient] = PrivateAttr(default=None) + _transport: Optional[Transport] = PrivateAttr(default=None) @model_validator(mode="before") def validate_path_and_file_name(cls, data: dict) -> dict: @@ -203,26 +205,26 @@ def transport(self) -> Transport: If the username and password are provided, use them to connect to the SFTP server. """ - if not self.__transport__: - self.__transport__ = Transport((self.host, self.port)) + if not self._transport: + self._transport = Transport((self.host, self.port)) if self.username and self.password: - self.__transport__.connect( + self._transport.connect( username=self.username.get_secret_value(), password=self.password.get_secret_value() ) else: - self.__transport__.connect() - return self.__transport__ + self._transport.connect() + return self._transport @property def client(self) -> SFTPClient: """Return the SFTP client. 
If it doesn't exist, create it.""" - if not self.__client__: + if not self._client: try: - self.__client__ = SFTPClient.from_transport(self.transport) + self._client = SFTPClient.from_transport(self.transport) except EOFError as e: self.log.error(f"Failed to create SFTP client. Transport active: {self.transport.is_active()}") raise e - return self.__client__ + return self._client def _close_client(self) -> None: """Close the SFTP client and transport.""" diff --git a/src/koheesio/integrations/spark/snowflake.py b/src/koheesio/integrations/spark/snowflake.py index 59686d9..8a4ad9a 100644 --- a/src/koheesio/integrations/spark/snowflake.py +++ b/src/koheesio/integrations/spark/snowflake.py @@ -43,10 +43,10 @@ from __future__ import annotations -import json from typing import Any, Callable, Dict, List, Optional, Set, Union from abc import ABC from copy import deepcopy +import json from textwrap import dedent from pyspark.sql import Window @@ -989,7 +989,7 @@ def extract(self) -> DataFrame: raise RuntimeError( f"Source table {self.source_table.table_name} does not have CDF enabled. " f"Set TBLPROPERTIES ('delta.enableChangeDataFeed' = true) to enable. " - f"Current properties = {self.source_table_properties}" + f"Current properties = {self.source_table.get_persisted_properties()}" ) df = self.reader.read() @@ -1042,17 +1042,21 @@ class TagSnowflakeQuery(Step, ExtraParamsMixin): ------- #### Using `options` parameter ```python - query_tag = AddQueryTag( - options={"preactions": "ALTER SESSION"}, - task_name="cleanse_task", - pipeline_name="ingestion-pipeline", - etl_date="2022-01-01", - pipeline_execution_time="2022-01-01T00:00:00", - task_execution_time="2022-01-01T01:00:00", - environment="dev", - trace_id="acd4f3f96045", - span_id="546d2d66f6cb", - ).execute().options + query_tag = ( + AddQueryTag( + options={"preactions": "ALTER SESSION"}, + task_name="cleanse_task", + pipeline_name="ingestion-pipeline", + etl_date="2022-01-01", + pipeline_execution_time="2022-01-01T00:00:00", + task_execution_time="2022-01-01T01:00:00", + environment="dev", + trace_id="acd4f3f96045", + span_id="546d2d66f6cb", + ) + .execute() + .options + ) ``` In this example, the query tag pre-action will be added to the Snowflake options. 
diff --git a/src/koheesio/integrations/spark/tableau/hyper.py b/src/koheesio/integrations/spark/tableau/hyper.py index 992d9f1..94230d8 100644 --- a/src/koheesio/integrations/spark/tableau/hyper.py +++ b/src/koheesio/integrations/spark/tableau/hyper.py @@ -1,6 +1,6 @@ -import os from typing import Any, List, Optional, Union from abc import ABC, abstractmethod +import os from pathlib import PurePath from tempfile import TemporaryDirectory @@ -435,7 +435,8 @@ def clean_dataframe(self) -> DataFrame: if d_col.dataType.precision > 18: # noinspection PyUnresolvedReferences _df = _df.withColumn( - d_col.name, col(d_col.name).cast(DecimalType(precision=18, scale=d_col.dataType.scale)) # type: ignore + d_col.name, + col(d_col.name).cast(DecimalType(precision=18, scale=d_col.dataType.scale)), # type: ignore ) if len(decimal_col_names) > 0: _df = _df.na.fill(0.0, decimal_col_names) diff --git a/src/koheesio/integrations/spark/tableau/server.py b/src/koheesio/integrations/spark/tableau/server.py index 7770f62..14d80a6 100644 --- a/src/koheesio/integrations/spark/tableau/server.py +++ b/src/koheesio/integrations/spark/tableau/server.py @@ -1,9 +1,8 @@ -import os from typing import Any, ContextManager, Optional, Union from enum import Enum +import os from pathlib import PurePath -import urllib3 # type: ignore from tableauserverclient import ( DatasourceItem, PersonalAccessTokenAuth, @@ -12,6 +11,7 @@ ) from tableauserverclient.server.pager import Pager from tableauserverclient.server.server import Server +import urllib3 # type: ignore from pydantic import Field, SecretStr diff --git a/src/koheesio/logger.py b/src/koheesio/logger.py index cad2213..82474b6 100644 --- a/src/koheesio/logger.py +++ b/src/koheesio/logger.py @@ -29,12 +29,12 @@ from __future__ import annotations +from typing import Any, Dict, Generator, Generic, List, Optional, Tuple, TypeVar import inspect import logging +from logging import Formatter, Logger, LogRecord, getLogger import os import sys -from typing import Any, Dict, Generator, Generic, List, Optional, Tuple, TypeVar -from logging import Formatter, Logger, LogRecord, getLogger from uuid import uuid4 from warnings import warn diff --git a/src/koheesio/models/__init__.py b/src/koheesio/models/__init__.py index dd0a8c8..a2db492 100644 --- a/src/koheesio/models/__init__.py +++ b/src/koheesio/models/__init__.py @@ -11,10 +11,10 @@ from __future__ import annotations +from typing import Annotated, Any, Dict, List, Optional, Union from abc import ABC from functools import cached_property from pathlib import Path -from typing import Annotated, Any, Dict, List, Optional, Union # to ensure that koheesio.models is a drop in replacement for pydantic from pydantic import BaseModel as PydanticBaseModel @@ -407,7 +407,9 @@ def __add__(self, other: Union[Dict, BaseModel]) -> BaseModel: ```python step_output_1 = StepOutput(foo="bar") step_output_2 = StepOutput(lorem="ipsum") - (step_output_1 + step_output_2) # step_output_1 will now contain {'foo': 'bar', 'lorem': 'ipsum'} + ( + step_output_1 + step_output_2 + ) # step_output_1 will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Parameters @@ -531,7 +533,9 @@ def merge(self, other: Union[Dict, BaseModel]) -> BaseModel: -------- ```python step_output = StepOutput(foo="bar") - step_output.merge({"lorem": "ipsum"}) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} + step_output.merge( + {"lorem": "ipsum"} + ) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Parameters diff --git a/src/koheesio/models/reader.py 
b/src/koheesio/models/reader.py index 3f35192..c122794 100644 --- a/src/koheesio/models/reader.py +++ b/src/koheesio/models/reader.py @@ -2,8 +2,8 @@ Module for the BaseReader class """ -from abc import ABC, abstractmethod from typing import Optional, TypeVar +from abc import ABC, abstractmethod from koheesio import Step diff --git a/src/koheesio/models/sql.py b/src/koheesio/models/sql.py index f19bc96..86ad38e 100644 --- a/src/koheesio/models/sql.py +++ b/src/koheesio/models/sql.py @@ -1,8 +1,8 @@ """This module contains the base class for SQL steps.""" +from typing import Any, Dict, Optional, Union from abc import ABC from pathlib import Path -from typing import Any, Dict, Optional, Union from koheesio import Step from koheesio.models import ExtraParamsMixin, Field, model_validator diff --git a/src/koheesio/notifications/slack.py b/src/koheesio/notifications/slack.py index 423b2f3..d38655b 100644 --- a/src/koheesio/notifications/slack.py +++ b/src/koheesio/notifications/slack.py @@ -2,9 +2,9 @@ Classes to ease interaction with Slack """ +from typing import Any, Dict, Optional import datetime import json -from typing import Any, Dict, Optional from textwrap import dedent from koheesio.models import ConfigDict, Field diff --git a/src/koheesio/spark/__init__.py b/src/koheesio/spark/__init__.py index c72cfb0..2beccf4 100644 --- a/src/koheesio/spark/__init__.py +++ b/src/koheesio/spark/__init__.py @@ -4,9 +4,9 @@ from __future__ import annotations -import warnings -from abc import ABC from typing import Optional +from abc import ABC +import warnings from pydantic import Field diff --git a/src/koheesio/spark/delta.py b/src/koheesio/spark/delta.py index 8d252a6..e4ef31c 100644 --- a/src/koheesio/spark/delta.py +++ b/src/koheesio/spark/delta.py @@ -2,8 +2,8 @@ Module for creating and managing Delta tables. """ -import warnings from typing import Dict, List, Optional, Union +import warnings from py4j.protocol import Py4JJavaError # type: ignore diff --git a/src/koheesio/spark/readers/hana.py b/src/koheesio/spark/readers/hana.py index 7616856..a98bed8 100644 --- a/src/koheesio/spark/readers/hana.py +++ b/src/koheesio/spark/readers/hana.py @@ -27,11 +27,12 @@ class HanaReader(JdbcReader): ```python from koheesio.spark.readers.hana import HanaReader + jdbc_hana = HanaReader( url="jdbc:sap://:/?", user="YOUR_USERNAME", password="***", - dbtable="schema_name.table_name" + dbtable="schema_name.table_name", ) df = jdbc_hana.read() ``` diff --git a/src/koheesio/spark/readers/memory.py b/src/koheesio/spark/readers/memory.py index 7900205..a90e09e 100644 --- a/src/koheesio/spark/readers/memory.py +++ b/src/koheesio/spark/readers/memory.py @@ -2,13 +2,14 @@ Create Spark DataFrame directly from the data stored in a Python variable """ -import json +from typing import Any, Dict, Optional, Union from enum import Enum from functools import partial from io import StringIO -from typing import Any, Dict, Optional, Union +import json import pandas as pd + from pyspark.sql.types import StructType from koheesio.models import ExtraParamsMixin, Field diff --git a/src/koheesio/spark/snowflake.py b/src/koheesio/spark/snowflake.py index 67cab02..4a70b8f 100644 --- a/src/koheesio/spark/snowflake.py +++ b/src/koheesio/spark/snowflake.py @@ -41,10 +41,10 @@ environments and make sure to install required JARs. 
""" -import json from typing import Any, Callable, Dict, List, Optional, Set, Union from abc import ABC from copy import deepcopy +import json from textwrap import dedent from pyspark.sql import Window @@ -666,9 +666,7 @@ def get_query(self, role: str) -> str: query : str The Query that performs the grant """ - query = ( - f"GRANT {','.join(self.privileges)} ON {self.type} {self.object} TO ROLE {role}".upper() - ) # nosec B608: hardcoded_sql_expressions + query = f"GRANT {','.join(self.privileges)} ON {self.type} {self.object} TO ROLE {role}".upper() # nosec B608: hardcoded_sql_expressions return query def execute(self) -> SnowflakeStep.Output: @@ -950,17 +948,21 @@ class TagSnowflakeQuery(Step, ExtraParamsMixin): Example ------- ```python - query_tag = AddQueryTag( - options={"preactions": ...}, - task_name="cleanse_task", - pipeline_name="ingestion-pipeline", - etl_date="2022-01-01", - pipeline_execution_time="2022-01-01T00:00:00", - task_execution_time="2022-01-01T01:00:00", - environment="dev", - trace_id="acd4f3f96045", - span_id="546d2d66f6cb", - ).execute().options + query_tag = ( + AddQueryTag( + options={"preactions": ...}, + task_name="cleanse_task", + pipeline_name="ingestion-pipeline", + etl_date="2022-01-01", + pipeline_execution_time="2022-01-01T00:00:00", + task_execution_time="2022-01-01T01:00:00", + environment="dev", + trace_id="acd4f3f96045", + span_id="546d2d66f6cb", + ) + .execute() + .options + ) ``` """ @@ -1320,7 +1322,7 @@ def extract(self) -> DataFrame: raise RuntimeError( f"Source table {self.source_table.table_name} does not have CDF enabled. " f"Set TBLPROPERTIES ('delta.enableChangeDataFeed' = true) to enable. " - f"Current properties = {self.source_table_properties}" + f"Current properties = {self.source_table.get_persisted_properties()}" ) df = self.reader.read() diff --git a/src/koheesio/spark/transformations/__init__.py b/src/koheesio/spark/transformations/__init__.py index 3f273a8..c44a998 100644 --- a/src/koheesio/spark/transformations/__init__.py +++ b/src/koheesio/spark/transformations/__init__.py @@ -56,7 +56,9 @@ class Transformation(SparkStep, ABC): class AddOne(Transformation): def execute(self): - self.output.df = self.df.withColumn("new_column", f.col("old_column") + 1) + self.output.df = self.df.withColumn( + "new_column", f.col("old_column") + 1 + ) ``` In the example above, the `execute` method is implemented to add 1 to the values of the `old_column` and store the diff --git a/src/koheesio/spark/transformations/camel_to_snake.py b/src/koheesio/spark/transformations/camel_to_snake.py index 33f5d23..f62c822 100644 --- a/src/koheesio/spark/transformations/camel_to_snake.py +++ b/src/koheesio/spark/transformations/camel_to_snake.py @@ -2,8 +2,8 @@ Class for converting DataFrame column names from camel case to snake case. 
""" -import re from typing import Optional +import re from koheesio.models import Field, ListOfColumns from koheesio.spark.transformations import ColumnsTransformation diff --git a/src/koheesio/spark/transformations/date_time/interval.py b/src/koheesio/spark/transformations/date_time/interval.py index e30244a..e56bd54 100644 --- a/src/koheesio/spark/transformations/date_time/interval.py +++ b/src/koheesio/spark/transformations/date_time/interval.py @@ -304,9 +304,7 @@ def adjust_time(column: Column, operation: Operations, interval: str) -> Column: operation = { "add": "try_add", "subtract": "try_subtract", - }[ - operation - ] # type: ignore + }[operation] # type: ignore except KeyError as e: raise ValueError(f"Operation '{operation}' is not valid. Must be either 'add' or 'subtract'.") from e diff --git a/src/koheesio/spark/transformations/drop_column.py b/src/koheesio/spark/transformations/drop_column.py index d4da777..9cd114e 100644 --- a/src/koheesio/spark/transformations/drop_column.py +++ b/src/koheesio/spark/transformations/drop_column.py @@ -46,5 +46,5 @@ class DropColumn(ColumnsTransformation): """ def execute(self) -> ColumnsTransformation.Output: - self.log.info(f"{self.column=}") + self.log.info(f"{self.columns=}") self.output.df = self.df.drop(*self.columns) diff --git a/src/koheesio/spark/transformations/strings/trim.py b/src/koheesio/spark/transformations/strings/trim.py index ce116e2..6f72eae 100644 --- a/src/koheesio/spark/transformations/strings/trim.py +++ b/src/koheesio/spark/transformations/strings/trim.py @@ -15,8 +15,8 @@ from typing import Literal -import pyspark.sql.functions as f from pyspark.sql import Column +import pyspark.sql.functions as f from koheesio.models import Field, ListOfColumns from koheesio.spark.transformations import ColumnsTransformationWithTarget diff --git a/src/koheesio/spark/transformations/uuid5.py b/src/koheesio/spark/transformations/uuid5.py index 545a2f9..cd709a8 100644 --- a/src/koheesio/spark/transformations/uuid5.py +++ b/src/koheesio/spark/transformations/uuid5.py @@ -1,7 +1,7 @@ """Ability to generate UUID5 using native pyspark (no udf)""" -import uuid from typing import Optional, Union +import uuid from pyspark.sql import functions as f diff --git a/src/koheesio/spark/utils/common.py b/src/koheesio/spark/utils/common.py index 10050d5..1f9b47c 100644 --- a/src/koheesio/spark/utils/common.py +++ b/src/koheesio/spark/utils/common.py @@ -2,11 +2,11 @@ Spark Utility functions """ +from typing import Union +from enum import Enum import importlib import inspect import os -from typing import Union -from enum import Enum from types import ModuleType from pyspark import sql diff --git a/src/koheesio/spark/writers/buffer.py b/src/koheesio/spark/writers/buffer.py index e94b5f8..e83880e 100644 --- a/src/koheesio/spark/writers/buffer.py +++ b/src/koheesio/spark/writers/buffer.py @@ -15,10 +15,10 @@ from __future__ import annotations -import gzip from typing import AnyStr, Literal, Optional from abc import ABC from functools import partial +import gzip from os import linesep from tempfile import SpooledTemporaryFile @@ -252,6 +252,15 @@ class PandasCsvBufferWriter(BufferWriter, ExtraParamsMixin): "by default. Can be set to one of 'infer', 'gzip', 'bz2', 'zip', 'xz', 'zstd', or 'tar'. " "See Pandas documentation for more details.", ) + emptyValue: Optional[str] = Field( + default="", + description="The string to use for missing values. 
Koheesio sets this default to an empty string.", + ) + + nullValue: Optional[str] = Field( + default="", + description="The string to use for missing values. Koheesio sets this default to an empty string.", + ) # -- Pandas specific properties -- index: bool = Field( diff --git a/src/koheesio/spark/writers/delta/batch.py b/src/koheesio/spark/writers/delta/batch.py index 7fd8376..6959ef0 100644 --- a/src/koheesio/spark/writers/delta/batch.py +++ b/src/koheesio/spark/writers/delta/batch.py @@ -34,18 +34,19 @@ ``` """ +from typing import Callable, Dict, List, Optional, Set, Type, Union from functools import partial -from typing import List, Optional, Set, Type, Union -from delta.tables import DeltaMergeBuilder, DeltaTable +from delta.tables import DeltaMergeBuilder from py4j.protocol import Py4JError + from pyspark.sql import DataFrameWriter from koheesio.models import ExtraParamsMixin, Field, field_validator from koheesio.spark.delta import DeltaTableStep from koheesio.spark.utils import on_databricks from koheesio.spark.writers import BatchOutputMode, StreamingOutputMode, Writer -from koheesio.spark.writers.delta.utils import log_clauses +from koheesio.spark.writers.delta.utils import get_delta_table_for_name, log_clauses class DeltaTableWriter(Writer, ExtraParamsMixin): @@ -157,8 +158,9 @@ def __merge(self, merge_builder: Optional[DeltaMergeBuilder] = None) -> Union[De if self.table.exists: merge_builder = self._get_merge_builder(merge_builder) + from koheesio.spark.utils.connect import is_remote_session - if on_databricks(): + if on_databricks() and not is_remote_session(): try: source_alias = merge_builder._jbuilder.getMergePlan().source().alias() target_alias = merge_builder._jbuilder.getMergePlan().target().alias() @@ -219,7 +221,7 @@ def __merge_all(self) -> Union[DeltaMergeBuilder, DataFrameWriter]: if self.table.exists: builder = ( - DeltaTable.forName(sparkSession=self.spark, tableOrViewName=self.table.table_name) + get_delta_table_for_name(spark_session=self.spark, table_name=self.table.table_name) .alias(target_alias) .merge(source=self.df.alias(source_alias), condition=merge_cond) .whenMatchedUpdateAll(condition=update_cond) @@ -266,7 +268,7 @@ def _merge_builder_from_args(self) -> DeltaMergeBuilder: target_alias = self.params.get("target_alias", "target") builder = ( - DeltaTable.forName(self.spark, self.table.table_name) + get_delta_table_for_name(spark_session=self.spark, table_name=self.table.table_name) .alias(target_alias) .merge(self.df.alias(source_alias), merge_cond) ) @@ -359,7 +361,7 @@ def __data_frame_writer(self) -> DataFrameWriter: @property def writer(self) -> Union[DeltaMergeBuilder, DataFrameWriter]: """Specify DeltaTableWriter""" - map_mode_to_writer = { + map_mode_to_writer: Dict[str, Callable] = { BatchOutputMode.MERGEALL.value: self.__merge_all, BatchOutputMode.MERGE.value: self.__merge, } diff --git a/src/koheesio/spark/writers/delta/scd.py b/src/koheesio/spark/writers/delta/scd.py index f93762e..6f0087e 100644 --- a/src/koheesio/spark/writers/delta/scd.py +++ b/src/koheesio/spark/writers/delta/scd.py @@ -30,6 +30,7 @@ from koheesio.spark.delta import DeltaTableStep from koheesio.spark.functions import current_timestamp_utc from koheesio.spark.writers import Writer +from koheesio.spark.writers.delta.utils import get_delta_table_for_name class SCD2DeltaTableWriter(Writer): @@ -476,7 +477,7 @@ def execute(self) -> None: """ self.df: DataFrame self.spark: SparkSession - delta_table = DeltaTable.forName(sparkSession=self.spark, 
tableOrViewName=self.table.table_name) + delta_table = get_delta_table_for_name(spark_session=self.spark, table_name=self.table.table_name) src_alias, cross_alias, dest_alias = "src", "cross", "tgt" # Prepare required merge columns diff --git a/src/koheesio/spark/writers/delta/stream.py b/src/koheesio/spark/writers/delta/stream.py index aea03a5..ada8f5b 100644 --- a/src/koheesio/spark/writers/delta/stream.py +++ b/src/koheesio/spark/writers/delta/stream.py @@ -2,8 +2,8 @@ This module defines the DeltaTableStreamWriter class, which is used to write streaming dataframes to Delta tables. """ -from email.policy import default from typing import Optional +from email.policy import default from pydantic import Field diff --git a/src/koheesio/spark/writers/delta/utils.py b/src/koheesio/spark/writers/delta/utils.py index 2e08a16..03c5d75 100644 --- a/src/koheesio/spark/writers/delta/utils.py +++ b/src/koheesio/spark/writers/delta/utils.py @@ -4,7 +4,21 @@ from typing import Optional -from py4j.java_gateway import JavaObject # type: ignore[import-untyped] +from delta import DeltaTable +from py4j.java_gateway import JavaObject + +from koheesio.spark import SparkSession +from koheesio.spark.utils import SPARK_MINOR_VERSION + + +class SparkConnectDeltaTableException(AttributeError): + EXCEPTION_CONNECT_TEXT: str = """`DeltaTable.forName` is not supported due to delta calling _sc, + which is not available in Spark Connect and PySpark>=3.5,<4.0. Required version of PySpark >=4.0. + Possible workaround to use spark.read and Spark SQL for any Delta operation (e.g. merge)""" + + def __init__(self, original_exception: AttributeError): + custom_message = f"{self.EXCEPTION_CONNECT_TEXT}\nOriginal exception: {str(original_exception)}" + super().__init__(custom_message) def log_clauses(clauses: JavaObject, source_alias: str, target_alias: str) -> Optional[str]: @@ -68,3 +82,41 @@ def log_clauses(clauses: JavaObject, source_alias: str, target_alias: str) -> Op ) return log_message + + +def get_delta_table_for_name(spark_session: SparkSession, table_name: str) -> DeltaTable: + """ + Retrieves the DeltaTable instance for the specified table name. + + This method attempts to get the DeltaTable using the provided Spark session and table name. + If an AttributeError occurs and the Spark version is between 3.4 and 4.0, and the session is remote, + it raises a SparkConnectDeltaTableException. + + Parameters + ---------- + spark_session : SparkSession + The Spark Session to use. + table_name : str + The table name. + + Returns + ------- + DeltaTable + The DeltaTable instance for the specified table name. + + Raises + ------ + SparkConnectDeltaTableException + If the Spark version is between 3.4 and 4.0, the session is remote, and an AttributeError occurs. 
+ """ + try: + delta_table = DeltaTable.forName(sparkSession=spark_session, tableOrViewName=table_name) + except AttributeError as e: + from koheesio.spark.utils.connect import is_remote_session + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + raise SparkConnectDeltaTableException(e) from e + else: + raise e + + return delta_table diff --git a/src/koheesio/spark/writers/dummy.py b/src/koheesio/spark/writers/dummy.py index 5e90a98..a618381 100644 --- a/src/koheesio/spark/writers/dummy.py +++ b/src/koheesio/spark/writers/dummy.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Union from koheesio.models import Field, PositiveInt, field_validator +from koheesio.spark import DataFrame from koheesio.spark.utils import show_string from koheesio.spark.writers import Writer diff --git a/src/koheesio/steps/__init__.py b/src/koheesio/steps/__init__.py index 89c04ba..5a1faa7 100644 --- a/src/koheesio/steps/__init__.py +++ b/src/koheesio/steps/__init__.py @@ -16,18 +16,18 @@ from __future__ import annotations +from typing import Any, Callable, Optional +from abc import abstractmethod +from functools import partialmethod, wraps import inspect import json import sys import warnings -from typing import Any, Callable, Union -from abc import abstractmethod -from functools import partialmethod, wraps import yaml from pydantic import BaseModel as PydanticBaseModel -from pydantic import InstanceOf +from pydantic import InstanceOf, PrivateAttr from koheesio.models import BaseModel, ConfigDict, ModelMetaclass @@ -535,21 +535,21 @@ def execute(self) -> MyStep.Output: class Output(StepOutput): """Output class for Step""" - __output__: Output + _output: Optional[Output] = PrivateAttr(default=None) @property def output(self) -> Output: """Interact with the output of the Step""" - if not self.__output__: - self.__output__ = self.Output.lazy() - self.__output__.name = self.name + ".Output" # type: ignore[operator] - self.__output__.description = "Output for " + self.name # type: ignore[operator] - return self.__output__ + if not self._output: + self._output = self.Output.lazy() + self._output.name = self.name + ".Output" # type: ignore + self._output.description = "Output for " + self.name # type: ignore + return self._output @output.setter def output(self, value: Output) -> None: """Set the output of the Step""" - self.__output__ = value + self._output = value @abstractmethod def execute(self) -> InstanceOf[StepOutput]: @@ -675,23 +675,6 @@ def repr_yaml(self, simple: bool = False) -> str: return yaml.dump(_result) - def __getattr__(self, key: str) -> Union[Any, None]: - """__getattr__ dunder - - Allows input to be accessed through `self.input_name` - - Parameters - ---------- - key: str - Name of the attribute to return the value of - - Returns - ------- - Any - The value of the attribute - """ - return self.model_dump().get(key) - @classmethod def from_step(cls, step: Step, **kwargs) -> InstanceOf[PydanticBaseModel]: # type: ignore[no-untyped-def] """Returns a new Step instance based on the data of another Step or BaseModel instance""" diff --git a/src/koheesio/steps/http.py b/src/koheesio/steps/http.py index 68329cc..8a16a8f 100644 --- a/src/koheesio/steps/http.py +++ b/src/koheesio/steps/http.py @@ -12,9 +12,9 @@ In the above example, the `response` variable will contain the JSON response from the HTTP request. 
""" -import json from typing import Any, Dict, List, Optional, Union from enum import Enum +import json import requests # type: ignore[import-untyped] diff --git a/src/koheesio/utils.py b/src/koheesio/utils.py index 253a985..0556a39 100644 --- a/src/koheesio/utils.py +++ b/src/koheesio/utils.py @@ -2,14 +2,14 @@ Utility functions """ -import datetime -import inspect -import uuid from typing import Any, Callable, Dict, Optional, Tuple +import datetime from functools import partial from importlib import import_module +import inspect from pathlib import Path from sys import version_info as PYTHON_VERSION +import uuid __all__ = [ "get_args_for_func", diff --git a/tests/asyncio/test_asyncio_http.py b/tests/asyncio/test_asyncio_http.py index 8625c71..5dcbc11 100644 --- a/tests/asyncio/test_asyncio_http.py +++ b/tests/asyncio/test_asyncio_http.py @@ -1,8 +1,8 @@ import warnings -import pytest from aiohttp import ClientResponseError, ClientSession, TCPConnector from aiohttp_retry import ExponentialRetry +import pytest from yarl import URL from pydantic import ValidationError diff --git a/tests/conftest.py b/tests/conftest.py index a0090a0..36e7d74 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import os +from pathlib import Path import time import uuid -from pathlib import Path import pytest diff --git a/tests/core/test_logger.py b/tests/core/test_logger.py index 277d766..c739ae8 100644 --- a/tests/core/test_logger.py +++ b/tests/core/test_logger.py @@ -1,5 +1,5 @@ -import logging from io import StringIO +import logging from logging import Logger from unittest.mock import MagicMock, patch diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 3d981f9..29e0ea6 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -1,5 +1,5 @@ -import json from typing import Optional +import json from textwrap import dedent import pytest diff --git a/tests/snowflake/test_snowflake.py b/tests/snowflake/test_snowflake.py index 0541bdf..8cf2fb4 100644 --- a/tests/snowflake/test_snowflake.py +++ b/tests/snowflake/test_snowflake.py @@ -1,8 +1,8 @@ # flake8: noqa: F811 from unittest import mock -import pytest from pydantic_core._pydantic_core import ValidationError +import pytest from koheesio.integrations.snowflake import ( GrantPrivilegesOnObject, @@ -158,7 +158,6 @@ def test_with_missing_dependencies(self): class TestSnowflakeBaseModel: - def test_get_options_using_alias(self): """Test that the options are correctly generated using alias""" k = SnowflakeBaseModel( diff --git a/tests/spark/conftest.py b/tests/spark/conftest.py index b0a7c51..f918ae4 100644 --- a/tests/spark/conftest.py +++ b/tests/spark/conftest.py @@ -1,15 +1,15 @@ +from collections import namedtuple import datetime +from decimal import Decimal import os +from pathlib import Path import socket import sys -from collections import namedtuple -from decimal import Decimal -from pathlib import Path from textwrap import dedent from unittest import mock -import pytest from delta import configure_spark_with_delta_pip +import pytest from pyspark.sql import SparkSession from pyspark.sql.types import ( diff --git a/tests/spark/integrations/snowflake/test_sync_task.py b/tests/spark/integrations/snowflake/test_sync_task.py index a4c50e8..980e03e 100644 --- a/tests/spark/integrations/snowflake/test_sync_task.py +++ b/tests/spark/integrations/snowflake/test_sync_task.py @@ -3,13 +3,12 @@ from unittest import mock import chispa -import pytest from conftest import await_job_completion +import pytest 
import pydantic from koheesio.integrations.snowflake import SnowflakeRunQueryPython -from koheesio.integrations.snowflake.test_utils import mock_query from koheesio.integrations.spark.snowflake import ( SnowflakeWriter, SynchronizeDeltaToSnowflakeTask, diff --git a/tests/spark/readers/test_auto_loader.py b/tests/spark/readers/test_auto_loader.py index 8f2b168..71e6cea 100644 --- a/tests/spark/readers/test_auto_loader.py +++ b/tests/spark/readers/test_auto_loader.py @@ -1,5 +1,5 @@ -import pytest from chispa import assert_df_equality +import pytest from pyspark.sql.types import * diff --git a/tests/spark/readers/test_memory.py b/tests/spark/readers/test_memory.py index 21b5d53..19c5b5c 100644 --- a/tests/spark/readers/test_memory.py +++ b/tests/spark/readers/test_memory.py @@ -1,5 +1,6 @@ -import pytest from chispa import assert_df_equality +import pytest + from pyspark.sql.types import StructType from koheesio.spark.readers.memory import DataFormat, InMemoryDataReader diff --git a/tests/spark/readers/test_rest_api.py b/tests/spark/readers/test_rest_api.py index 9c22ea3..0803f63 100644 --- a/tests/spark/readers/test_rest_api.py +++ b/tests/spark/readers/test_rest_api.py @@ -1,7 +1,7 @@ -import pytest -import requests_mock from aiohttp import ClientSession, TCPConnector from aiohttp_retry import ExponentialRetry +import pytest +import requests_mock from yarl import URL from pyspark.sql.types import MapType, StringType, StructField, StructType diff --git a/tests/spark/test_delta.py b/tests/spark/test_delta.py index 1806ac0..920d2b6 100644 --- a/tests/spark/test_delta.py +++ b/tests/spark/test_delta.py @@ -3,8 +3,8 @@ from pathlib import Path from unittest.mock import patch -import pytest from conftest import setup_test_data +import pytest from pydantic import ValidationError diff --git a/tests/spark/writers/delta/test_delta_writer.py b/tests/spark/writers/delta/test_delta_writer.py index 92a349c..c916b0d 100644 --- a/tests/spark/writers/delta/test_delta_writer.py +++ b/tests/spark/writers/delta/test_delta_writer.py @@ -1,9 +1,8 @@ import os from unittest.mock import MagicMock, patch -import pytest from conftest import await_job_completion -from delta import DeltaTable +import pytest from pydantic import ValidationError @@ -14,13 +13,14 @@ from koheesio.spark.utils import SPARK_MINOR_VERSION from koheesio.spark.writers import BatchOutputMode, StreamingOutputMode from koheesio.spark.writers.delta import DeltaTableStreamWriter, DeltaTableWriter -from koheesio.spark.writers.delta.utils import log_clauses +from koheesio.spark.writers.delta.utils import ( + SparkConnectDeltaTableException, + log_clauses, +) from koheesio.spark.writers.stream import Trigger pytestmark = pytest.mark.spark -skip_reason = "Tests are not working with PySpark 3.5 due to delta calling _sc. 
Test requires pyspark version >= 4.0" - def test_delta_table_writer(dummy_df, spark): table_name = "test_table" @@ -53,9 +53,6 @@ def test_delta_partitioning(spark, sample_df_to_partition): def test_delta_table_merge_all(spark): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - table_name = "test_merge_all_table" target_df = spark.createDataFrame( [{"id": 1, "value": "no_merge"}, {"id": 2, "value": "expected_merge"}, {"id": 5, "value": "xxxx"}] @@ -77,7 +74,7 @@ def test_delta_table_merge_all(spark): 5: "xxxx", } DeltaTableWriter(table=table_name, output_mode=BatchOutputMode.APPEND, df=target_df).execute() - DeltaTableWriter( + merge_writer = DeltaTableWriter( table=table_name, output_mode=BatchOutputMode.MERGEALL, output_mode_params={ @@ -86,33 +83,46 @@ def test_delta_table_merge_all(spark): "insert_cond": F.expr("source.value IS NOT NULL"), }, df=source_df, - ).execute() - result = { - list(row.asDict().values())[0]: list(row.asDict().values())[1] for row in spark.read.table(table_name).collect() - } - assert result == expected + ) + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + merge_writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + merge_writer.execute() + result = { + list(row.asDict().values())[0]: list(row.asDict().values())[1] + for row in spark.read.table(table_name).collect() + } + assert result == expected def test_deltatablewriter_with_invalid_conditions(spark, dummy_df): from koheesio.spark.utils.connect import is_remote_session - - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) + from koheesio.spark.writers.delta.utils import get_delta_table_for_name table_name = "delta_test_table" - merge_builder = ( - DeltaTable.forName(sparkSession=spark, tableOrViewName=table_name) - .alias("target") - .merge(condition="invalid_condition", source=dummy_df.alias("source")) - ) - writer = DeltaTableWriter( - table=table_name, - output_mode=BatchOutputMode.MERGE, - output_mode_params={"merge_builder": merge_builder}, - df=dummy_df, - ) - with pytest.raises(AnalysisException): - writer.execute() + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + builder = get_delta_table_for_name(spark_session=spark, table_name=table_name) + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + with pytest.raises(AnalysisException): + builder = get_delta_table_for_name(spark_session=spark, table_name=table_name) + merge_builder = builder.alias("target").merge( + condition="invalid_condition", source=dummy_df.alias("source") + ) + writer = DeltaTableWriter( + table=table_name, + output_mode=BatchOutputMode.MERGE, + output_mode_params={"merge_builder": merge_builder}, + df=dummy_df, + ) + writer.execute() @patch.dict( @@ -286,9 +296,6 @@ def test_delta_with_options(spark): def test_merge_from_args(spark, dummy_df): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - table_name = "test_table_merge_from_args" dummy_df.write.format("delta").saveAsTable(table_name) @@ -316,14 +323,20 @@ def test_merge_from_args(spark, dummy_df): "merge_cond": 
"source.id=target.id", }, ) - writer._merge_builder_from_args() - mock_delta_builder.whenMatchedUpdate.assert_called_once_with( - set={"id": "source.id"}, condition="source.id=target.id" - ) - mock_delta_builder.whenNotMatchedInsert.assert_called_once_with( - values={"id": "source.id"}, condition="source.id IS NOT NULL" - ) + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer._merge_builder_from_args() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer._merge_builder_from_args() + mock_delta_builder.whenMatchedUpdate.assert_called_once_with( + set={"id": "source.id"}, condition="source.id=target.id" + ) + mock_delta_builder.whenNotMatchedInsert.assert_called_once_with( + values={"id": "source.id"}, condition="source.id IS NOT NULL" + ) @pytest.mark.parametrize( @@ -350,9 +363,6 @@ def test_merge_from_args_raise_value_error(spark, output_mode_params): def test_merge_no_table(spark): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - table_name = "test_merge_no_table" target_df = spark.createDataFrame( [{"id": 1, "value": "no_merge"}, {"id": 2, "value": "expected_merge"}, {"id": 5, "value": "expected_merge"}] @@ -388,20 +398,29 @@ def test_merge_no_table(spark): ], "merge_cond": "source.id=target.id", } - - DeltaTableWriter( + writer1 = DeltaTableWriter( df=target_df, table=table_name, output_mode=BatchOutputMode.MERGE, output_mode_params=params - ).execute() - - DeltaTableWriter( + ) + writer2 = DeltaTableWriter( df=source_df, table=table_name, output_mode=BatchOutputMode.MERGE, output_mode_params=params - ).execute() + ) + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + writer1.execute() - result = { - list(row.asDict().values())[0]: list(row.asDict().values())[1] for row in spark.read.table(table_name).collect() - } + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer2.execute() - assert result == expected + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer1.execute() + writer2.execute() + + result = { + list(row.asDict().values())[0]: list(row.asDict().values())[1] + for row in spark.read.table(table_name).collect() + } + + assert result == expected def test_log_clauses(mocker): diff --git a/tests/spark/writers/delta/test_scd.py b/tests/spark/writers/delta/test_scd.py index 087f957..df4106b 100644 --- a/tests/spark/writers/delta/test_scd.py +++ b/tests/spark/writers/delta/test_scd.py @@ -1,9 +1,9 @@ -import datetime from typing import List, Optional +import datetime -import pytest from delta import DeltaTable from delta.tables import DeltaMergeBuilder +import pytest from pydantic import Field @@ -16,18 +16,14 @@ from koheesio.spark.functions import current_timestamp_utc from koheesio.spark.utils import SPARK_MINOR_VERSION from koheesio.spark.writers.delta.scd import SCD2DeltaTableWriter +from koheesio.spark.writers.delta.utils import SparkConnectDeltaTableException pytestmark = pytest.mark.spark -skip_reason = "Tests are not working with PySpark 3.5 due to delta calling _sc. 
Test requires pyspark version >= 4.0" - def test_scd2_custom_logic(spark): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - def _get_result(target_df: DataFrame, expr: str): res = ( target_df.where(expr) @@ -145,124 +141,161 @@ def _prepare_merge_builder( meta_scd2_end_time_col_name="valid_to_timestamp", df=source_df, ) - writer.execute() - expected = { - "id": 4, - "last_updated_at": datetime.datetime(2024, 4, 1, 0, 0), - "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), - "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), - "value": "value-4", - } + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() - target_df = spark.read.table(target_table) - result = _get_result(target_df, "id = 4") + expected = { + "id": 4, + "last_updated_at": datetime.datetime(2024, 4, 1, 0, 0), + "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), + "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), + "value": "value-4", + } - assert spark.table(target_table).count() == 4 - assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 - assert result == expected + target_df = spark.read.table(target_table) + result = _get_result(target_df, "id = 4") + + assert spark.table(target_table).count() == 4 + assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 + assert result == expected source_df2 = source_df.withColumn( "value", F.expr("CASE WHEN id = 2 THEN 'value-2-change' ELSE value END") ).withColumn("last_updated_at", F.expr("CASE WHEN id = 2 THEN TIMESTAMP'2024-02-02' ELSE last_updated_at END")) writer.df = source_df2 - writer.execute() - - expected_insert = { - "id": 2, - "last_updated_at": datetime.datetime(2024, 2, 2, 0, 0), - "valid_from_timestamp": datetime.datetime(2024, 2, 2, 0, 0), - "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), - "value": "value-2-change", - } - - expected_update = { - "id": 2, - "last_updated_at": datetime.datetime(2024, 2, 1, 0, 0), - "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), - "valid_to_timestamp": datetime.datetime(2024, 2, 2, 0, 0), - "value": "value-2", - } - - result_insert = _get_result(target_df, "id = 2 and meta.valid_to_timestamp = '2999-12-31'") - result_update = _get_result(target_df, "id = 2 and meta.valid_from_timestamp = '1970-01-01'") - - assert spark.table(target_table).count() == 5 - assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 - assert result_insert == expected_insert - assert result_update == expected_update + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() + + expected_insert = { + "id": 2, + "last_updated_at": datetime.datetime(2024, 2, 2, 0, 0), + "valid_from_timestamp": datetime.datetime(2024, 2, 2, 0, 0), + "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), + "value": "value-2-change", + } + + expected_update = { + "id": 2, + "last_updated_at": datetime.datetime(2024, 2, 1, 0, 0), + 
"valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), + "valid_to_timestamp": datetime.datetime(2024, 2, 2, 0, 0), + "value": "value-2", + } + + result_insert = _get_result(target_df, "id = 2 and meta.valid_to_timestamp = '2999-12-31'") + result_update = _get_result(target_df, "id = 2 and meta.valid_from_timestamp = '1970-01-01'") + + assert spark.table(target_table).count() == 5 + assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 + assert result_insert == expected_insert + assert result_update == expected_update source_df3 = source_df2.withColumn( "value", F.expr("CASE WHEN id = 3 THEN 'value-3-change' ELSE value END") ).withColumn("last_updated_at", F.expr("CASE WHEN id = 3 THEN TIMESTAMP'2024-03-02' ELSE last_updated_at END")) writer.df = source_df3 - writer.execute() - - expected_insert = { - "id": 3, - "last_updated_at": datetime.datetime(2024, 3, 2, 0, 0), - "valid_from_timestamp": datetime.datetime(2024, 3, 2, 0, 0), - "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), - "value": "value-3-change", - } - - expected_update = { - "id": 3, - "last_updated_at": datetime.datetime(2024, 3, 1, 0, 0), - "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), - "valid_to_timestamp": datetime.datetime(2024, 3, 2, 0, 0), - "value": None, - } - - result_insert = _get_result(target_df, "id = 3 and meta.valid_to_timestamp = '2999-12-31'") - result_update = _get_result(target_df, "id = 3 and meta.valid_from_timestamp = '1970-01-01'") - - assert spark.table(target_table).count() == 6 - assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 - assert result_insert == expected_insert - assert result_update == expected_update + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() + + expected_insert = { + "id": 3, + "last_updated_at": datetime.datetime(2024, 3, 2, 0, 0), + "valid_from_timestamp": datetime.datetime(2024, 3, 2, 0, 0), + "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), + "value": "value-3-change", + } + + expected_update = { + "id": 3, + "last_updated_at": datetime.datetime(2024, 3, 1, 0, 0), + "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), + "valid_to_timestamp": datetime.datetime(2024, 3, 2, 0, 0), + "value": None, + } + + result_insert = _get_result(target_df, "id = 3 and meta.valid_to_timestamp = '2999-12-31'") + result_update = _get_result(target_df, "id = 3 and meta.valid_from_timestamp = '1970-01-01'") + + assert spark.table(target_table).count() == 6 + assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 + assert result_insert == expected_insert + assert result_update == expected_update source_df4 = source_df3.where("id != 4") writer.df = source_df4 - writer.execute() + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() - assert spark.table(target_table).count() == 6 - assert spark.table(target_table).where("id = 4 and meta.valid_to_timestamp = '2999-12-31'").count() == 1 + assert spark.table(target_table).count() == 6 + assert 
spark.table(target_table).where("id = 4 and meta.valid_to_timestamp = '2999-12-31'").count() == 1 writer.orphaned_records_close = True - writer.execute() + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() - assert spark.table(target_table).count() == 6 - assert spark.table(target_table).where("id = 4 and meta.valid_to_timestamp = '2999-12-31'").count() == 0 + assert spark.table(target_table).count() == 6 + assert spark.table(target_table).where("id = 4 and meta.valid_to_timestamp = '2999-12-31'").count() == 0 source_df5 = source_df4.where("id != 5") writer.orphaned_records_close_ts = F.col("snapshot_at") writer.df = source_df5 - writer.execute() - expected = { - "id": 5, - "last_updated_at": datetime.datetime(2024, 5, 1, 0, 0), - "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), - "valid_to_timestamp": datetime.datetime(2024, 12, 31, 0, 0), - "value": "value-5", - } - result = _get_result(target_df, "id = 5") + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() + + expected = { + "id": 5, + "last_updated_at": datetime.datetime(2024, 5, 1, 0, 0), + "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), + "valid_to_timestamp": datetime.datetime(2024, 12, 31, 0, 0), + "value": "value-5", + } + result = _get_result(target_df, "id = 5") - assert spark.table(target_table).count() == 6 - assert spark.table(target_table).where("id = 5 and meta.valid_to_timestamp = '2999-12-31'").count() == 0 - assert result == expected + assert spark.table(target_table).count() == 6 + assert spark.table(target_table).where("id = 5 and meta.valid_to_timestamp = '2999-12-31'").count() == 0 + assert result == expected def test_scd2_logic(spark): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - changes_data = [ [("key1", "value1", "scd1-value11", "2024-05-01"), ("key2", "value2", "scd1-value21", "2024-04-01")], [("key1", "value1_updated", "scd1-value12", "2024-05-02"), ("key3", "value3", "scd1-value31", "2024-05-03")], @@ -400,11 +433,17 @@ def test_scd2_logic(spark): changes_df = spark.createDataFrame(changes, ["merge_key", "value_scd2", "value_scd1", "run_date"]) changes_df = changes_df.withColumn("run_date", F.to_timestamp("run_date")) writer.df = changes_df - writer.execute() - res = ( - spark.sql("SELECT merge_key,value_scd2, value_scd1, _scd2.* FROM scd2_test_data_set") - .orderBy("merge_key", "effective_time") - .collect() - ) + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() + res = ( + spark.sql("SELECT merge_key,value_scd2, value_scd1, _scd2.* FROM scd2_test_data_set") + .orderBy("merge_key", "effective_time") + .collect() + ) - assert res == expected + assert res == expected diff --git a/tests/spark/writers/test_buffer.py b/tests/spark/writers/test_buffer.py index 6da783a..141bf57 100644 --- 
a/tests/spark/writers/test_buffer.py +++ b/tests/spark/writers/test_buffer.py @@ -1,5 +1,5 @@ -import gzip from datetime import datetime, timezone +import gzip from importlib.util import find_spec import pytest diff --git a/tests/spark/writers/test_sftp.py b/tests/spark/writers/test_sftp.py index 7119edd..a19e2fb 100644 --- a/tests/spark/writers/test_sftp.py +++ b/tests/spark/writers/test_sftp.py @@ -1,8 +1,8 @@ from unittest import mock import paramiko -import pytest from paramiko import SSHException +import pytest from koheesio.integrations.spark.sftp import ( SendCsvToSftp, diff --git a/tests/sso/test_okta.py b/tests/sso/test_okta.py index 5247493..8c7a548 100644 --- a/tests/sso/test_okta.py +++ b/tests/sso/test_okta.py @@ -1,5 +1,5 @@ -import logging from io import StringIO +import logging import pytest from requests_mock.mocker import Mocker diff --git a/tests/steps/test_steps.py b/tests/steps/test_steps.py index 92c563a..484f3b1 100644 --- a/tests/steps/test_steps.py +++ b/tests/steps/test_steps.py @@ -1,11 +1,11 @@ from __future__ import annotations -import io -import warnings from copy import deepcopy from functools import wraps +import io from unittest import mock from unittest.mock import call, patch +import warnings import pytest