Skip to content

Commit

Permalink
Fix test and import
Browse files Browse the repository at this point in the history
  • Loading branch information
maddiedawson committed Aug 22, 2023
1 parent 8653b41 commit a51619f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ exclude = [
"docs/source/conf.py"
]

# Disable checks for missing imports, as a conditional install of streaming will not include them
# Any incorrect imports will be discovered through test cases
reportMissingImports = "none"
reportUnnecessaryIsInstance = "warning"
reportMissingTypeStubs = "none"
reportIncompatibleMethodOverride = "none"
Expand Down
11 changes: 7 additions & 4 deletions streaming/base/storage/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,16 +618,19 @@ def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False) -> None:
super().__init__(out, keep_local, progress_bar)
self.client = self._create_workspace_client()
self.dbfs_path = self.remote.lstrip('dbfs:') # pyright: ignore
self.check_folder_exists()

def _create_workspace_client(self):
    """Build and return a Databricks ``WorkspaceClient``.

    The ``databricks-sdk`` package is an optional dependency, so it is
    imported lazily here rather than at module scope. If the package is
    missing, the ImportError's message is rewritten with an actionable
    installation hint before propagating.

    Returns:
        WorkspaceClient: A freshly constructed Databricks workspace client.

    Raises:
        ImportError: If ``databricks.sdk`` is not installed.
    """
    try:
        from databricks.sdk import WorkspaceClient
    except ImportError as e:
        e.msg = get_import_exception_message(e.name)  # pyright: ignore
        # Bare `raise` re-raises the active exception without adding an
        # extra frame to the traceback (unlike `raise e`).
        raise
    return WorkspaceClient()

def upload_file(self, filename: str):
"""Upload file from local instance to DBFS. Does not overwrite.
Expand Down
13 changes: 10 additions & 3 deletions tests/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,19 +319,26 @@ def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]):

class TestDBFSUploader:

@patch('streaming.base.storage.upload.DBFSUploader._create_workspace_client')
@pytest.mark.parametrize('out', ['dbfs:/container/dir', ('./dir1', 'dbfs:/container/dir/')])
def test_instantiation(self, mock_create_client: Mock, out: Any):
    """DBFSUploader accepts a dbfs remote, alone or as a (local, remote) pair."""
    # The patched client factory returns a plain Mock; no real SDK call is made.
    mock_create_client.side_effect = None
    _ = DBFSUploader(out=out)
    # The tuple form creates a local staging directory — remove it afterwards.
    if isinstance(out, tuple):
        shutil.rmtree(out[0], ignore_errors=True)

@patch('streaming.base.storage.upload.DBFSUploader._create_workspace_client')
@pytest.mark.parametrize('out', ['ss4://bucket/dir', ('./dir1', 'gcs://bucket/dir/')])
def test_invalid_remote_list(self, mock_create_client: Mock, out: Any):
    """A remote without the dbfs prefix must be rejected with ValueError."""
    # The patched client factory returns a plain Mock; no real SDK call is made.
    mock_create_client.side_effect = None
    # Plain string: the original's `f` prefix was pointless (no placeholders).
    with pytest.raises(ValueError, match='Invalid Cloud provider prefix.*'):
        _ = DBFSUploader(out=out)

def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]):
@patch('streaming.base.storage.upload.DBFSUploader._create_workspace_client')
def test_local_directory_is_empty(self, mock_create_client: Mock,
local_remote_dir: Tuple[str, str]):
with pytest.raises(FileExistsError, match=f'Directory is not empty.*'):
mock_create_client.side_effect = None
local, _ = local_remote_dir
os.makedirs(local, exist_ok=True)
local_file_path = os.path.join(local, 'file.txt')
Expand Down

0 comments on commit a51619f

Please sign in to comment.