Skip to content

Commit

Permalink
as good as I can make it
Browse files Browse the repository at this point in the history
  • Loading branch information
dannymeijer committed Oct 29, 2024
1 parent 1c742b4 commit b63dc7c
Show file tree
Hide file tree
Showing 55 changed files with 310 additions and 312 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -607,14 +607,14 @@ python_version = "3.10"
files = ["koheesio/**/*.py"]
plugins = ["pydantic.mypy"]
pretty = true
warn_unused_configs = true
check_untyped_defs = false
disallow_untyped_calls = false
disallow_untyped_defs = true
warn_unused_configs = true
warn_no_return = false
implicit_optional = true
allow_untyped_globals = true
disable_error_code = ["attr-defined", "return-value"]
disable_error_code = ["attr-defined", "return-value", "union-attr", "override"]

[tool.pylint.main]
fail-under = 9.5
Expand Down
54 changes: 31 additions & 23 deletions src/koheesio/integrations/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ class Box(Step, ABC):
description="Private key passphrase generated in the app management console.",
)

client: SkipValidation[Client] = None
client: SkipValidation[Client] = None # type: ignore

def init_client(self):
def init_client(self) -> None:
"""Set up the Box client."""
if not self.client:
self.client = Client(JWTAuth(**self.auth_options))

@property
def auth_options(self):
def auth_options(self) -> Dict[str, Any]:
"""
Get a dictionary of authentication options, that can be handily used in the child classes
"""
Expand All @@ -126,11 +126,11 @@ def auth_options(self):
"rsa_private_key_passphrase": self.rsa_private_key_passphrase.get_secret_value(),
}

def __init__(self, **data):
def __init__(self, **data: dict):
super().__init__(**data)
self.init_client()

def execute(self):
def execute(self) -> Step.Output: # type: ignore
# Plug to be able to unit test ABC
pass

Expand Down Expand Up @@ -167,7 +167,7 @@ class Output(StepOutput):
folder: Optional[Folder] = Field(default=None, description="Box folder object")

@model_validator(mode="after")
def validate_folder_or_path(self):
def validate_folder_or_path(self) -> "BoxFolderBase":
"""
Validations for 'folder' and 'path' parameter usage
"""
Expand All @@ -183,13 +183,13 @@ def validate_folder_or_path(self):
return self

@property
def _obj_from_id(self):
def _obj_from_id(self) -> Folder:
"""
Get folder object from identifier
"""
return self.client.folder(folder_id=self.folder).get() if isinstance(self.folder, str) else self.folder

def action(self):
def action(self) -> Optional[Folder]:
"""
Placeholder for 'action' method, that should be implemented in the child classes
Expand Down Expand Up @@ -223,7 +223,7 @@ class BoxFolderGet(BoxFolderBase):
False, description="Create sub-folders recursively if the path does not exist."
)

def _get_or_create_folder(self, current_folder_object, next_folder_name):
def _get_or_create_folder(self, current_folder_object: Folder, next_folder_name: str) -> Folder:
"""
Get or create a folder.
Expand All @@ -238,6 +238,11 @@ def _get_or_create_folder(self, current_folder_object, next_folder_name):
-------
next_folder_object: Folder
Next folder object.
Raises
------
BoxFolderNotFoundError
If the folder does not exist and 'create_sub_folders' is set to False.
"""
for item in current_folder_object.get_items():
if item.type == "folder" and item.name == next_folder_name:
Expand All @@ -251,7 +256,7 @@ def _get_or_create_folder(self, current_folder_object, next_folder_name):
"to create required directory structure automatically."
)

def action(self):
def action(self) -> Folder:
"""
Get folder action
Expand All @@ -267,7 +272,9 @@ def action(self):

if self.path:
cleaned_path_parts = [p for p in PurePath(self.path).parts if p.strip() not in [None, "", " ", "/"]]
current_folder_object = self.client.folder(folder_id=self.root) if isinstance(self.root, str) else self.root
current_folder_object: Union[Folder, str] = (
self.client.folder(folder_id=self.root) if isinstance(self.root, str) else self.root
)

for next_folder_name in cleaned_path_parts:
current_folder_object = self._get_or_create_folder(current_folder_object, next_folder_name)
Expand Down Expand Up @@ -295,7 +302,7 @@ class BoxFolderCreate(BoxFolderGet):
)

@field_validator("folder")
def validate_folder(cls, folder):
def validate_folder(cls, folder: Any) -> None:
"""
Validate 'folder' parameter
"""
Expand All @@ -322,7 +329,7 @@ class BoxFolderDelete(BoxFolderBase):
```
"""

def action(self):
def action(self) -> None:
"""
Delete folder action
Expand All @@ -345,7 +352,7 @@ class BoxReaderBase(Box, Reader, ABC):
"""

schema_: Optional[StructType] = Field(
None,
default=None,
alias="schema",
description="[Optional] Schema that will be applied during the creation of Spark DataFrame",
)
Expand Down Expand Up @@ -388,7 +395,7 @@ class BoxCsvFileReader(BoxReaderBase):

file: Union[str, list[str]] = Field(default=..., description="ID or list of IDs for the files to read.")

def execute(self):
def execute(self) -> BoxReaderBase.Output:
"""
Loop through the list of provided file identifiers and load data into dataframe.
For traceability purposes the following columns will be added to the dataframe:
Expand All @@ -409,6 +416,7 @@ def execute(self):
temp_df_pandas = pd.read_csv(data_buffer, header=0, dtype=str if not self.schema_ else None, **self.params) # type: ignore
temp_df = self.spark.createDataFrame(temp_df_pandas, schema=self.schema_)

# type: ignore
temp_df = (
temp_df
# fmt: off
Expand Down Expand Up @@ -450,9 +458,9 @@ class BoxCsvPathReader(BoxReaderBase):
"""

path: str = Field(default=..., description="Box path")
filter: Optional[str] = Field(default=r".csv|.txt$", description="[Optional] Regexp to filter folder contents")
filter: str = Field(default=r".csv|.txt$", description="[Optional] Regexp to filter folder contents")

def execute(self):
def execute(self) -> BoxReaderBase.Output:
"""
Identify the list of files from the source Box path that match desired filter and load them into Dataframe
"""
Expand Down Expand Up @@ -501,13 +509,13 @@ class BoxFileBase(Box):
)
path: Optional[str] = Field(default=None, description="Path to the Box folder, for example: `folder/sub-folder/lz")

def action(self, file: File, folder: Folder):
def action(self, file: File, folder: Folder) -> None:
"""
Abstract class for File level actions.
"""
raise NotImplementedError

def execute(self):
def execute(self) -> Box.Output:
"""
Generic execute method for all BoxToBox interactions. Deals with getting the correct folder and file objects
from various parameter inputs
Expand Down Expand Up @@ -541,7 +549,7 @@ class BoxToBoxFileCopy(BoxFileBase):
```
"""

def action(self, file: File, folder: Folder):
def action(self, file: File, folder: Folder) -> None:
"""
Copy file to the desired destination and extend file description with the processing info
Expand Down Expand Up @@ -577,7 +585,7 @@ class BoxToBoxFileMove(BoxFileBase):
```
"""

def action(self, file: File, folder: Folder):
def action(self, file: File, folder: Folder) -> None:
"""
Move file to the desired destination and extend file description with the processing info
Expand Down Expand Up @@ -632,15 +640,15 @@ class Output(StepOutput):
shared_link: str = Field(default=..., description="Shared link for the Box file")

@model_validator(mode="before")
def validate_name_for_binary_data(cls, values):
def validate_name_for_binary_data(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate 'file_name' parameter when providing a binary input for 'file'."""
file, file_name = values.get("file"), values.get("file_name")
if not isinstance(file, str) and not file_name:
raise AttributeError("The parameter 'file_name' is mandatory when providing a binary input for 'file'.")

return values

def action(self):
def action(self) -> None:
_file = self.file
_name = self.file_name

Expand Down
1 change: 0 additions & 1 deletion src/koheesio/integrations/spark/dq/spark_expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from pydantic import Field

import pyspark
from pyspark import sql

from koheesio.spark import DataFrame
from koheesio.spark.transformations import Transformation
Expand Down
36 changes: 18 additions & 18 deletions src/koheesio/integrations/spark/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ class SFTPWriteMode(str, Enum):
UPDATE = "update"

@classmethod
def from_string(cls, mode: str):
def from_string(cls, mode: str) -> "SFTPWriteMode":
"""Return the SFTPWriteMode for the given string."""
return cls[mode.upper()]

@property
def write_mode(self):
def write_mode(self) -> str:
"""Return the write mode for the given SFTPWriteMode."""
if self in {SFTPWriteMode.OVERWRITE, SFTPWriteMode.BACKUP, SFTPWriteMode.EXCLUSIVE, SFTPWriteMode.UPDATE}:
return "wb" # Overwrite, Backup, Exclusive, Update modes set the file to be written from the beginning
Expand Down Expand Up @@ -148,7 +148,7 @@ class SFTPWriter(Writer):

mode: SFTPWriteMode = Field(
default=SFTPWriteMode.OVERWRITE,
description="Write mode: overwrite, append, ignore, exclusive, backup, or update." + SFTPWriteMode.__doc__,
description="Write mode: overwrite, append, ignore, exclusive, backup, or update." + SFTPWriteMode.__doc__, # type: ignore
)

# private attrs
Expand Down Expand Up @@ -179,26 +179,26 @@ def validate_path_and_file_name(cls, data: dict) -> dict:
return data

@field_validator("host")
def validate_sftp_host(cls, v) -> str:
def validate_sftp_host(cls, host: str) -> str:
"""Validate the host"""
# remove the sftp:// prefix if present
if v.startswith("sftp://"):
v = v.replace("sftp://", "")
if host.startswith("sftp://"):
host = host.replace("sftp://", "")

# remove the trailing slash if present
if v.endswith("/"):
v = v[:-1]
if host.endswith("/"):
host = host[:-1]

return v
return host

@property
def write_mode(self):
def write_mode(self) -> str:
"""Return the write mode for the given SFTPWriteMode."""
mode = SFTPWriteMode.from_string(self.mode) # Convert string to SFTPWriteMode
return mode.write_mode

@property
def transport(self):
def transport(self) -> Transport:
"""Return the transport for the SFTP connection. If it doesn't exist, create it.
If the username and password are provided, use them to connect to the SFTP server.
Expand All @@ -224,14 +224,14 @@ def client(self) -> SFTPClient:
raise e
return self.__client__

def _close_client(self):
def _close_client(self) -> None:
"""Close the SFTP client and transport."""
if self.client:
self.client.close()
if self.transport:
self.transport.close()

def write_file(self, file_path: str, buffer_output: InstanceOf[BufferWriter.Output]):
def write_file(self, file_path: str, buffer_output: InstanceOf[BufferWriter.Output]) -> None:
"""
Using Paramiko, write the data in the buffer to SFTP.
"""
Expand Down Expand Up @@ -292,7 +292,7 @@ def _handle_write_mode(self, file_path: str, buffer_output: InstanceOf[BufferWri
# Then overwrite the file
self.write_file(file_path, buffer_output)

def execute(self):
def execute(self) -> Writer.Output:
buffer_output: InstanceOf[BufferWriter.Output] = self.buffer_writer.write(self.df)

# write buffer to the SFTP server
Expand Down Expand Up @@ -377,15 +377,15 @@ class SendCsvToSftp(PandasCsvBufferWriter, SFTPWriter):
For more details on the CSV parameters, refer to the PandasCsvBufferWriter class documentation.
"""

buffer_writer: PandasCsvBufferWriter = Field(default=None, validate_default=False)
buffer_writer: Optional[PandasCsvBufferWriter] = Field(default=None, validate_default=False)

@model_validator(mode="after")
def set_up_buffer_writer(self) -> "SendCsvToSftp":
"""Set up the buffer writer, passing all CSV related options to it."""
self.buffer_writer = PandasCsvBufferWriter(**self.get_options(options_type="koheesio_pandas_buffer_writer"))
return self

def execute(self):
def execute(self) -> SFTPWriter.Output:
SFTPWriter.execute(self)


Expand Down Expand Up @@ -459,7 +459,7 @@ class SendJsonToSftp(PandasJsonBufferWriter, SFTPWriter):
For more details on the JSON parameters, refer to the PandasJsonBufferWriter class documentation.
"""

buffer_writer: PandasJsonBufferWriter = Field(default=None, validate_default=False)
buffer_writer: Optional[PandasJsonBufferWriter] = Field(default=None, validate_default=False)

@model_validator(mode="after")
def set_up_buffer_writer(self) -> "SendJsonToSftp":
Expand All @@ -469,5 +469,5 @@ def set_up_buffer_writer(self) -> "SendJsonToSftp":
)
return self

def execute(self):
def execute(self) -> SFTPWriter.Output:
SFTPWriter.execute(self)
Loading

0 comments on commit b63dc7c

Please sign in to comment.