Commit
Consistently use the upload(path, IO) and download(path) -> IO across file-related operations (#148)

## Changes

- added `w.workspace.upload` & `w.workspace.download`
- added `w.dbfs.upload` & `w.dbfs.download`
- added `w.files.upload` & `w.files.download`
- modified the low-level client to work with raw streams and log debug messages correctly

Fix #104
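
For illustration, a minimal usage sketch of the unified API surface (paths are placeholders; assumes an authenticated `WorkspaceClient`, mirroring the integration tests below):

```python
import io

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# Workspace objects: notebooks and files
w.workspace.upload('/Users/me@example.com/hello.py', io.BytesIO(b'print(1)'))
with w.workspace.download('/Users/me@example.com/hello.py') as f:
    print(f.read())

# DBFS
w.dbfs.upload('/tmp/hello.txt', io.BytesIO(b'hello'), overwrite=True)
with w.dbfs.download('/tmp/hello.txt') as f:
    print(f.read())

# Files API (Unity Catalog volumes)
w.files.upload('/Volumes/main/default/vol/hello.txt', io.BytesIO(b'hello'))
with w.files.download('/Volumes/main/default/vol/hello.txt') as f:
    print(f.read())
w.files.delete('/Volumes/main/default/vol/hello.txt')
```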

## Tests
New integration tests.
nfx authored Jun 8, 2023
1 parent 038e48a commit 087cf3f
Showing 10 changed files with 233 additions and 16 deletions.
6 changes: 4 additions & 2 deletions .codegen/__init__.py.tmpl
@@ -1,8 +1,9 @@
import databricks.sdk.core as client
import databricks.sdk.dbutils as dbutils

-from databricks.sdk.mixins.dbfs import DbfsExt
+from databricks.sdk.mixins.files import DbfsExt, FilesMixin
from databricks.sdk.mixins.compute import ClustersExt
+from databricks.sdk.mixins.workspace import WorkspaceExt
{{- range .Services}}
from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}

@@ -11,7 +12,7 @@ from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}
"azure_client_id" "azure_tenant_id" "azure_environment" "auth_type" "cluster_id"}}

{{- define "api" -}}
-{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" -}}
+{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" -}}
{{- $genApi := concat .PascalName "API" -}}
{{- getOrDefault $mixins $genApi $genApi -}}
{{- end -}}
@@ -34,6 +35,7 @@ class WorkspaceClient:
self.config = config
self.dbutils = dbutils.RemoteDbUtils(self.config)
self.api_client = client.ApiClient(self.config)
+self.files = FilesMixin(self.api_client)
{{- range .Services}}{{if not .IsAccounts}}
self.{{.SnakeName}} = {{template "api" .}}(self.api_client){{end -}}{{end}}

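The `$mixins` dict in the template swaps hand-written extension classes in for generated API classes at code-generation time. A rough Python analogue of the `getOrDefault` lookup (illustrative only, not generated code):

```python
# Illustrative analogue of the template's getOrDefault lookup:
mixins = {'ClustersAPI': 'ClustersExt', 'DbfsAPI': 'DbfsExt', 'WorkspaceAPI': 'WorkspaceExt'}

def class_for(generated_api: str) -> str:
    # fall back to the generated class when no mixin overrides it
    return mixins.get(generated_api, generated_api)

assert class_for('WorkspaceAPI') == 'WorkspaceExt'
assert class_for('JobsAPI') == 'JobsAPI'
```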
6 changes: 4 additions & 2 deletions databricks/sdk/__init__.py

(Generated file; diff not rendered.)

27 changes: 21 additions & 6 deletions databricks/sdk/core.py
@@ -852,20 +852,32 @@ def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest
r.headers[k] = v
return r

-def do(self, method: str, path: str, query: dict = None, body: dict = None) -> dict:
+def do(self,
+       method: str,
+       path: str,
+       query: dict = None,
+       body: dict = None,
+       raw: bool = False,
+       files=None,
+       data=None) -> dict:
headers = {'Accept': 'application/json', 'User-Agent': self._user_agent_base}
response = self._session.request(method,
f"{self._cfg.host}{path}",
params=query,
json=body,
-                                     headers=headers)
+                                     headers=headers,
+                                     files=files,
+                                     data=data,
+                                     stream=True if raw else False)
try:
-        self._record_request_log(response)
+        self._record_request_log(response, raw=raw or data is not None or files is not None)
if not response.ok:
# TODO: experiment with traceback pruning for better readability
# See https://stackoverflow.com/a/58821552/277035
payload = response.json()
raise self._make_nicer_error(status_code=response.status_code, **payload) from None
+        if raw:
+            return response.raw
if not len(response.content):
return {}
return response.json()
@@ -894,7 +906,7 @@ def _make_nicer_error(self, status_code: int = 200, **kwargs) -> DatabricksError
kwargs['message'] = message
return DatabricksError(**kwargs)

-def _record_request_log(self, response: requests.Response):
+def _record_request_log(self, response: requests.Response, raw=False):
if not logger.isEnabledFor(logging.DEBUG):
return
request = response.request
@@ -909,9 +921,12 @@ def _record_request_log(self, response: requests.Response):
for k, v in request.headers.items():
sb.append(f'> * {k}: {self._only_n_bytes(v, self._debug_truncate_bytes)}')
if request.body:
-        sb.append(self._redacted_dump("> ", request.body))
+        sb.append("> [raw stream]" if raw else self._redacted_dump("> ", request.body))
sb.append(f'< {response.status_code} {response.reason}')
-    if response.content:
+    if raw and response.headers.get('Content-Type', None) != 'application/json':
+        # Raw streams with `Transfer-Encoding: chunked` do not have `Content-Type` header
+        sb.append("< [raw stream]")
+    elif response.content:
sb.append(self._redacted_dump("< ", response.content))
logger.debug("\n".join(sb))

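With `raw=True`, `do()` passes `stream=True` to `requests`, so the body is not eagerly buffered and the caller receives `response.raw`, a file-like `urllib3` stream. A sketch of how a caller might consume it (placeholder path; `w.api_client` is the low-level client wired up by `WorkspaceClient`):

```python
import shutil

# Stream a large response to disk without holding the whole body in memory.
raw = w.api_client.do('GET', '/api/2.0/fs/files/Volumes/main/default/vol/big.bin', raw=True)
with open('/tmp/big.bin', 'wb') as sink:
    shutil.copyfileobj(raw, sink)
```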
2 changes: 1 addition & 1 deletion databricks/sdk/dbutils.py
@@ -7,7 +7,7 @@

from .core import ApiClient, Config, DatabricksError
from .mixins import compute as compute_ext
-from .mixins import dbfs as dbfs_ext
+from .mixins import files as dbfs_ext
from .service import compute, workspace

_LOG = logging.getLogger('databricks.sdk')
25 changes: 24 additions & 1 deletion databricks/sdk/mixins/dbfs.py → databricks/sdk/mixins/files.py
@@ -2,12 +2,13 @@

import base64
import pathlib
+import shutil
import sys
from abc import ABC, abstractmethod
from types import TracebackType
from typing import TYPE_CHECKING, AnyStr, BinaryIO, Iterable, Iterator, Type

-from databricks.sdk.core import DatabricksError
+from databricks.sdk.core import ApiClient, DatabricksError

from ..service import files

@@ -313,6 +314,13 @@ class DbfsExt(files.DbfsAPI):
def open(self, path: str, *, read: bool = False, write: bool = False, overwrite: bool = False) -> _DbfsIO:
return _DbfsIO(self, path, read=read, write=write, overwrite=overwrite)

+    def upload(self, path: str, src: BinaryIO, *, overwrite: bool = False):
+        with self.open(path, write=True, overwrite=overwrite) as dst:
+            shutil.copyfileobj(src, dst)
+
+    def download(self, path: str) -> BinaryIO:
+        return self.open(path, read=True)

def list(self, path: str, *, recursive=False) -> Iterator[files.FileInfo]:
"""List directory contents or file details.
@@ -385,3 +393,18 @@ def move_(self, src: str, dst: str, *, recursive=False, overwrite=False):
# do cross-fs moving
self.copy(src, dst, recursive=recursive, overwrite=overwrite)
source.delete(recursive=recursive)


+class FilesMixin:
+
+    def __init__(self, api_client: ApiClient):
+        self._api = api_client
+
+    def upload(self, path: str, src: BinaryIO):
+        self._api.do('PUT', f'/api/2.0/fs/files{path}', data=src)
+
+    def download(self, path: str) -> BinaryIO:
+        return self._api.do('GET', f'/api/2.0/fs/files{path}', raw=True)
+
+    def delete(self, path: str):
+        self._api.do('DELETE', f'/api/2.0/fs/files{path}')
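
Since `DbfsExt.upload` drives the existing `open()` writer through `shutil.copyfileobj`, any readable binary file object works as a source, not just `io.BytesIO`. A sketch with placeholder paths (assumes a configured `WorkspaceClient` `w`):

```python
import shutil

# Upload a local file by streaming it through the DBFS writer.
with open('/tmp/local.bin', 'rb') as src:
    w.dbfs.upload('/tmp/remote.bin', src, overwrite=True)

# download() returns the file-like _DbfsIO reader; stream it back to disk.
with w.dbfs.download('/tmp/remote.bin') as remote, open('/tmp/copy.bin', 'wb') as dst:
    shutil.copyfileobj(remote, dst)
```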
97 changes: 97 additions & 0 deletions databricks/sdk/mixins/workspace.py
@@ -0,0 +1,97 @@
from typing import BinaryIO, Iterator, Optional

from ..core import DatabricksError
from ..service.workspace import (ExportFormat, Language, ObjectInfo,
ObjectType, WorkspaceAPI)


def _fqcn(x: any) -> str:
return f'{x.__module__}.{x.__name__}'


class WorkspaceExt(WorkspaceAPI):

def list(self,
path: str,
*,
notebooks_modified_after: Optional[int] = None,
recursive: Optional[bool] = False,
**kwargs) -> Iterator[ObjectInfo]:
parent_list = super().list
queue = [path]
while queue:
path, queue = queue[0], queue[1:]
for object_info in parent_list(path, notebooks_modified_after=notebooks_modified_after):
if recursive and object_info.object_type == ObjectType.DIRECTORY:
queue.append(object_info.path)
continue
yield object_info

def upload(self,
path: str,
content: BinaryIO,
*,
format: Optional[ExportFormat] = None,
language: Optional[Language] = None,
overwrite: Optional[bool] = False) -> None:
"""
Uploads a workspace object (for example, a notebook or file) or the contents of an entire
directory (`DBC` format).
Errors:
* `RESOURCE_ALREADY_EXISTS`: if `path` already exists and `overwrite=True` is not set.
* `INVALID_PARAMETER_VALUE`: if `format` and `content` values are not compatible.
:param path: target location of the file on workspace.
:param content: file-like `io.BinaryIO` of the `path` contents.
:param format: By default, `ExportFormat.SOURCE`. If using `ExportFormat.AUTO` the `path`
is imported or exported as either a workspace file or a notebook, depending
on an analysis of the `item`’s extension and the header content provided in
the request. In addition, if the `path` is imported as a notebook, then
the `item`’s extension is automatically removed.
:param language: Only required if using `ExportFormat.SOURCE`.
"""
if format is not None and not isinstance(format, ExportFormat):
raise ValueError(
f'format is expected to be {_fqcn(ExportFormat)}, but got {_fqcn(format.__class__)}')
if (not format or format == ExportFormat.SOURCE) and not language:
suffixes = {
'.py': Language.PYTHON,
'.sql': Language.SQL,
'.scala': Language.SCALA,
'.R': Language.R
}
for sfx, lang in suffixes.items():
if path.endswith(sfx):
language = lang
break
if language is not None and not isinstance(language, Language):
raise ValueError(
f'language is expected to be {_fqcn(Language)}, but got {_fqcn(language.__class__)}')
data = {'path': path}
if format: data['format'] = format.value
if language: data['language'] = language.value
if overwrite: data['overwrite'] = 'true'
try:
return self._api.do('POST', '/api/2.0/workspace/import', files={'content': content}, data=data)
except DatabricksError as e:
if e.error_code == 'INVALID_PARAMETER_VALUE':
msg = f'Perhaps you forgot to specify the `format=ExportFormat.AUTO`. {e}'
raise DatabricksError(message=msg, error_code=e.error_code)
else:
raise e

def download(self, path: str, *, format: Optional[ExportFormat] = None) -> BinaryIO:
"""
Downloads notebook or file from the workspace
:param path: location of the file or notebook on workspace.
:param format: By default, `ExportFormat.SOURCE`. If using `ExportFormat.AUTO` the `path`
is imported or exported as either a workspace file or a notebook, depending
on an analysis of the `item`’s extension and the header content provided in
the request.
:return: file-like `io.BinaryIO` of the `path` contents.
"""
query = {'path': path, 'direct_download': 'true'}
if format: query['format'] = format.value
return self._api.do('GET', '/api/2.0/workspace/export', query=query, raw=True)
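
When `format` is omitted (or is `SOURCE`) and no `language` is given, `upload` infers the language from the path suffix (`.py`, `.sql`, `.scala`, `.R`); paths without a recognized suffix need an explicit `language`. A sketch with placeholder paths:

```python
import io

from databricks.sdk.service.workspace import ExportFormat, Language

# '.sql' suffix -> Language.SQL is inferred automatically
w.workspace.upload('/Users/me@example.com/query.sql', io.BytesIO(b'SELECT 1'))

# no recognized suffix -> pass the language explicitly
w.workspace.upload('/Users/me@example.com/nb', io.BytesIO(b'print(1)'),
                   format=ExportFormat.SOURCE, language=Language.PYTHON)
```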
1 change: 1 addition & 0 deletions databricks/sdk/service/sql.py

(Generated file; diff not rendered.)

23 changes: 23 additions & 0 deletions tests/integration/test_dbfs.py → tests/integration/test_files.py
@@ -1,3 +1,4 @@
+import io
import pathlib
from typing import List

@@ -170,3 +171,25 @@ def test_move_from_dbfs_to_local(w, random, junk, tmp_path):
assert f.read() == payload_02
with (tmp_path / root.name / 'a/b/03').open('rb') as f:
assert f.read() == payload_03


+def test_dbfs_upload_download(w, random, junk, tmp_path):
+    root = pathlib.Path(f'/tmp/{random()}')
+
+    f = io.BytesIO(b"some text data")
+    w.dbfs.upload(f'{root}/01', f)
+
+    with w.dbfs.download(f'{root}/01') as f:
+        assert f.read() == b"some text data"
+
+
+def test_files_api_upload_download(w, random):
+    pytest.skip()
+    f = io.BytesIO(b"some text data")
+    target_file = f'/Volumes/bogdanghita/default/v3_shared/sdk-testing/{random(10)}.txt'
+    w.files.upload(target_file, f)
+
+    with w.files.download(target_file) as f:
+        assert f.read() == b"some text data"
+
+    w.files.delete(target_file)
54 changes: 54 additions & 0 deletions tests/integration/test_workspace.py
@@ -0,0 +1,54 @@
import io

from databricks.sdk.service.workspace import ExportFormat, Language


def test_workspace_recursive_list(w, random):
names = []
for i in w.workspace.list(f'/Users/{w.current_user.me().user_name}', recursive=True):
names.append(i.path)
assert len(names) > 0


def test_workspace_upload_download_notebooks(w, random):
notebook = f'/Users/{w.current_user.me().user_name}/notebook-{random(12)}.py'

w.workspace.upload(notebook, io.BytesIO(b'print(1)'))
with w.workspace.download(notebook) as f:
content = f.read()
assert content == b'# Databricks notebook source\nprint(1)'

w.workspace.delete(notebook)


def test_workspace_upload_download_files(w, random):
py_file = f'/Users/{w.current_user.me().user_name}/file-{random(12)}.py'

w.workspace.upload(py_file, io.BytesIO(b'print(1)'), format=ExportFormat.AUTO)
with w.workspace.download(py_file) as f:
content = f.read()
assert content == b'print(1)'

w.workspace.delete(py_file)


def test_workspace_upload_download_txt_files(w, random):
txt_file = f'/Users/{w.current_user.me().user_name}/txt-{random(12)}.txt'

w.workspace.upload(txt_file, io.BytesIO(b'print(1)'), format=ExportFormat.AUTO)
with w.workspace.download(txt_file) as f:
content = f.read()
assert content == b'print(1)'

w.workspace.delete(txt_file)


def test_workspace_upload_download_notebooks_no_extension(w, random):
nb = f'/Users/{w.current_user.me().user_name}/notebook-{random(12)}'

w.workspace.upload(nb, io.BytesIO(b'print(1)'), format=ExportFormat.SOURCE, language=Language.PYTHON)
with w.workspace.download(nb) as f:
content = f.read()
assert content == b'# Databricks notebook source\nprint(1)'

w.workspace.delete(nb)
8 changes: 4 additions & 4 deletions tests/test_dbutils.py
@@ -10,7 +10,7 @@ def dbutils(config):


def test_fs_cp(dbutils, mocker):
-    inner = mocker.patch('databricks.sdk.mixins.dbfs.DbfsExt.copy')
+    inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.copy')

dbutils.fs.cp('a', 'b', recurse=True)

@@ -30,7 +30,7 @@ def test_fs_head(dbutils, mocker):

def test_fs_ls(dbutils, mocker):
from databricks.sdk.service.files import FileInfo
-    inner = mocker.patch('databricks.sdk.mixins.dbfs.DbfsExt.list',
+    inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.list',
return_value=[
FileInfo(path='b', file_size=10, modification_time=20),
FileInfo(path='c', file_size=30, modification_time=40),
@@ -53,7 +53,7 @@ def test_fs_mkdirs(dbutils, mocker):


def test_fs_mv(dbutils, mocker):
-    inner = mocker.patch('databricks.sdk.mixins.dbfs.DbfsExt.move_')
+    inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.move_')

dbutils.fs.mv('a', 'b')

@@ -75,7 +75,7 @@ def write(self, contents):
self._written = contents

mock_open = _MockOpen()
-    inner = mocker.patch('databricks.sdk.mixins.dbfs.DbfsExt.open', return_value=mock_open)
+    inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.open', return_value=mock_open)

dbutils.fs.put('a', 'b')

