Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/datacustomcode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
from datacustomcode.io.writer.print import PrintDataCloudWriter

__all__ = ["Client", "QueryAPIDataCloudReader", "PrintDataCloudWriter"]
__all__ = ["Client", "PrintDataCloudWriter", "QueryAPIDataCloudReader"]
10 changes: 5 additions & 5 deletions src/datacustomcode/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
# This lets all readers and writers to be findable via config
from datacustomcode.io import * # noqa: F403
from datacustomcode.io.base import BaseDataAccessLayer
from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH001
from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH001

DEFAULT_CONFIG_NAME = "config.yaml"


if TYPE_CHECKING:
from pyspark.sql import SparkSession
from datacustomcode.io.reader.base import BaseDataCloudReader
from datacustomcode.io.writer.base import BaseDataCloudWriter


class ForceableConfig(BaseModel):
Expand Down Expand Up @@ -72,7 +72,7 @@ class AccessLayerObjectConfig(ForceableConfig, Generic[_T]):

def to_object(self, spark: SparkSession) -> _T:
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
return cast(_T, type_(spark=spark, **self.options))
return cast("_T", type_(spark=spark, **self.options))


class SparkConfig(ForceableConfig):
Expand All @@ -90,8 +90,8 @@ class SparkConfig(ForceableConfig):


class ClientConfig(BaseModel):
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
reader_config: Union[AccessLayerObjectConfig["BaseDataCloudReader"], None] = None
writer_config: Union[AccessLayerObjectConfig["BaseDataCloudWriter"], None] = None
spark_config: Union[SparkConfig, None] = None

def update(self, other: ClientConfig) -> ClientConfig:
Expand Down
12 changes: 8 additions & 4 deletions src/datacustomcode/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import os
import shutil
import sys
import tempfile
import time
from typing import (
Expand Down Expand Up @@ -163,8 +164,12 @@ def prepare_dependency_archive(directory: str) -> None:

with tempfile.TemporaryDirectory() as temp_dir:
logger.info("Building dependencies archive")
shutil.copy("requirements.txt", temp_dir)
shutil.copy("build_native_dependencies.sh", temp_dir)
try:
shutil.copy("requirements.txt", temp_dir)
shutil.copy("build_native_dependencies.sh", temp_dir)
except FileNotFoundError as e:
logger.error(f"Error copying files: {e}")
sys.exit(1)
cmd = (
f"{PLATFORM_ENV_VAR} docker run --rm "
f"-v {temp_dir}:/workspace "
Expand Down Expand Up @@ -373,8 +378,7 @@ def zip(
for file in files:
if file != ".DS_Store":
file_path = os.path.join(root, file)
zipf.write(file_path)

zipf.write(file_path, arcname=file)
logger.debug(f"Created zip file: {ZIP_FILE_NAME}")


Expand Down
31 changes: 23 additions & 8 deletions tests/test_deploy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the deploy module."""

import shutil
from unittest.mock import (
MagicMock,
call,
Expand Down Expand Up @@ -218,12 +219,11 @@ def test_prepare_dependency_archive_docker_run_failure(
mock_cmd_output.assert_any_call(self.EXPECTED_DOCKER_RUN_CMD)

@patch("datacustomcode.deploy.cmd_output")
@patch("datacustomcode.deploy.shutil.copy")
@patch("datacustomcode.deploy.tempfile.TemporaryDirectory")
@patch("datacustomcode.deploy.os.path.join")
@patch("datacustomcode.deploy.os.makedirs")
def test_prepare_dependency_archive_file_copy_failure(
self, mock_makedirs, mock_join, mock_temp_dir, mock_copy, mock_cmd_output
self, mock_makedirs, mock_join, mock_temp_dir, mock_cmd_output
):
"""Test prepare_dependency_archive when file copy fails."""
# Mock the temporary directory context manager
Expand All @@ -235,17 +235,32 @@ def test_prepare_dependency_archive_file_copy_failure(
# Mock cmd_output to return image ID
mock_cmd_output.return_value = "abc123"

# Mock shutil.copy to raise exception
mock_copy.side_effect = FileNotFoundError("File not found")
# Mock os.path.join for archive path
mock_join.return_value = "/tmp/test_dir/native_dependencies.tar.gz"

with pytest.raises(FileNotFoundError, match="File not found"):
prepare_dependency_archive("/test/dir")
# Create a custom mock for shutil.copy that raises FileNotFoundError
# only for the specific calls we want to test
original_copy = shutil.copy

def mock_copy(src, dst):
    """Stand-in for shutil.copy: fail for the two dependency files, delegate otherwise."""
    failing_sources = ("requirements.txt", "build_native_dependencies.sh")
    if src in failing_sources:
        # Simulate the missing-file condition the test exercises.
        raise FileNotFoundError("File not found")
    return original_copy(src, dst)

with patch("datacustomcode.deploy.shutil.copy", side_effect=mock_copy):
# Call the function - it should catch the FileNotFoundError and call sys.exit(1)
# We expect it to raise SystemExit (which is what sys.exit(1) does)
with pytest.raises(SystemExit) as exc_info:
prepare_dependency_archive("/test/dir")

# Verify the exit code is 1
assert exc_info.value.code == 1

# Verify docker images command was called
mock_cmd_output.assert_any_call(self.EXPECTED_DOCKER_IMAGES_CMD)

# Verify files were attempted to be copied
mock_copy.assert_any_call("requirements.txt", "/tmp/test_dir")
# Verify files were attempted to be copied (the mock will have been called)
# Note: We can't easily verify the mock calls since we're using a custom function


class TestHasNonemptyRequirementsFile:
Expand Down
Loading