Consistently use the upload(path, IO) and download(path) -> IO across file-related operations #148

Merged 3 commits on Jun 8, 2023. Changes shown from all commits.
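The PR settles every file-like surface of the SDK on one shape: `upload(path, IO)` takes a target path plus a binary file-like object, and `download(path)` returns a binary file-like object. A rough usage sketch of the resulting client surface, assuming an authenticated `WorkspaceClient`; all paths and the user email are illustrative:

import io
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# DBFS: buffered through the existing _DbfsIO streams
w.dbfs.upload('/tmp/example.txt', io.BytesIO(b'hello'))
with w.dbfs.download('/tmp/example.txt') as f:
    assert f.read() == b'hello'

# Files API: raw PUT/GET against /api/2.0/fs/files (volume path is illustrative)
w.files.upload('/Volumes/main/default/vol/example.txt', io.BytesIO(b'hello'))
with w.files.download('/Volumes/main/default/vol/example.txt') as f:
    assert f.read() == b'hello'

# Workspace: import/export of notebooks and workspace files
w.workspace.upload('/Users/me@example.com/nb.py', io.BytesIO(b'print(1)'))
with w.workspace.download('/Users/me@example.com/nb.py') as f:
    print(f.read())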
6 changes: 4 additions & 2 deletions .codegen/__init__.py.tmpl
@@ -1,8 +1,9 @@
 import databricks.sdk.core as client
 import databricks.sdk.dbutils as dbutils

-from databricks.sdk.mixins.dbfs import DbfsExt
+from databricks.sdk.mixins.files import DbfsExt, FilesMixin
 from databricks.sdk.mixins.compute import ClustersExt
+from databricks.sdk.mixins.workspace import WorkspaceExt
 {{- range .Services}}
 from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}

@@ -11,7 +12,7 @@ from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}
"azure_client_id" "azure_tenant_id" "azure_environment" "auth_type" "cluster_id"}}

{{- define "api" -}}
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" -}}
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" -}}
{{- $genApi := concat .PascalName "API" -}}
{{- getOrDefault $mixins $genApi $genApi -}}
{{- end -}}
@@ -34,6 +35,7 @@ class WorkspaceClient:
         self.config = config
         self.dbutils = dbutils.RemoteDbUtils(self.config)
         self.api_client = client.ApiClient(self.config)
+        self.files = FilesMixin(self.api_client)
Contributor: Does this need to be added here explicitly because there is no Files tag in the OpenAPI spec?

Contributor (author): Yes, temporarily.

 {{- range .Services}}{{if not .IsAccounts}}
         self.{{.SnakeName}} = {{template "api" .}}(self.api_client){{end -}}{{end}}

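For readers unfamiliar with the codegen template: `$mixins` maps generated API class names to hand-written subclasses, and `getOrDefault` falls back to the generated class when no mixin is registered. A hypothetical sketch of what the rendered `WorkspaceClient.__init__` ends up containing after this change (illustrative rendering only, service list abbreviated; the real file is generated):

class WorkspaceClient:
    def __init__(self, config):
        self.config = config
        self.dbutils = dbutils.RemoteDbUtils(self.config)
        self.api_client = client.ApiClient(self.config)
        self.files = FilesMixin(self.api_client)        # wired by hand: no Files tag in the spec yet
        self.clusters = ClustersExt(self.api_client)    # ClustersAPI -> ClustersExt
        self.dbfs = DbfsExt(self.api_client)            # DbfsAPI -> DbfsExt
        self.workspace = WorkspaceExt(self.api_client)  # WorkspaceAPI -> WorkspaceExt (new mapping)
        # ...one attribute per remaining generated service, using its <PascalName>API class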
6 changes: 4 additions & 2 deletions databricks/sdk/__init__.py

(Generated file; diff not rendered.)

27 changes: 21 additions & 6 deletions databricks/sdk/core.py
@@ -825,20 +825,32 @@ def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
         r.headers[k] = v
         return r

-    def do(self, method: str, path: str, query: dict = None, body: dict = None) -> dict:
+    def do(self,
+           method: str,
+           path: str,
+           query: dict = None,
+           body: dict = None,
+           raw: bool = False,
+           files=None,
+           data=None) -> dict:
         headers = {'Accept': 'application/json', 'User-Agent': self._user_agent_base}
         response = self._session.request(method,
                                          f"{self._cfg.host}{path}",
                                          params=query,
                                          json=body,
-                                         headers=headers)
+                                         headers=headers,
+                                         files=files,
+                                         data=data,
+                                         stream=True if raw else False)
         try:
-            self._record_request_log(response)
+            self._record_request_log(response, raw=raw or data is not None or files is not None)
             if not response.ok:
                 # TODO: experiment with traceback pruning for better readability
                 # See https://stackoverflow.com/a/58821552/277035
                 payload = response.json()
                 raise self._make_nicer_error(status_code=response.status_code, **payload) from None
+            if raw:
+                return response.raw
             if not len(response.content):
                 return {}
             return response.json()
@@ -867,7 +879,7 @@ def _make_nicer_error(self, status_code: int = 200, **kwargs) -> DatabricksError:
         kwargs['message'] = message
         return DatabricksError(**kwargs)

-    def _record_request_log(self, response: requests.Response):
+    def _record_request_log(self, response: requests.Response, raw=False):
         if not logger.isEnabledFor(logging.DEBUG):
             return
         request = response.request
@@ -882,9 +894,12 @@ def _record_request_log(self, response: requests.Response):
         for k, v in request.headers.items():
             sb.append(f'> * {k}: {self._only_n_bytes(v, self._debug_truncate_bytes)}')
         if request.body:
-            sb.append(self._redacted_dump("> ", request.body))
+            sb.append("> [raw stream]" if raw else self._redacted_dump("> ", request.body))
         sb.append(f'< {response.status_code} {response.reason}')
-        if response.content:
+        if raw and response.headers.get('Content-Type', None) != 'application/json':
+            # Raw streams with `Transfer-Encoding: chunked` do not have a `Content-Type` header
+            sb.append("< [raw stream]")
+        elif response.content:
             sb.append(self._redacted_dump("< ", response.content))
         logger.debug("\n".join(sb))
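To make the `raw=True` contract concrete: `do()` now asks requests for a streamed response and returns the undecoded `response.raw` (a urllib3 stream), leaving the caller responsible for draining it. A minimal consumption sketch; the local and volume paths are illustrative:

import shutil

# download a large file without holding the whole body in memory
with w.files.download('/Volumes/main/default/vol/big.bin') as src, \
        open('/tmp/big.bin', 'wb') as dst:
    shutil.copyfileobj(src, dst)  # copies in chunks from the raw HTTP stream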
2 changes: 1 addition & 1 deletion databricks/sdk/dbutils.py
@@ -7,7 +7,7 @@

 from .core import ApiClient, Config, DatabricksError
 from .mixins import compute as compute_ext
-from .mixins import dbfs as dbfs_ext
+from .mixins import files as dbfs_ext
 from .service import compute, workspace

 _LOG = logging.getLogger('databricks.sdk')
25 changes: 24 additions & 1 deletion databricks/sdk/mixins/dbfs.py → databricks/sdk/mixins/files.py
@@ -2,12 +2,13 @@

 import base64
 import pathlib
+import shutil
 import sys
 from abc import ABC, abstractmethod
 from types import TracebackType
 from typing import TYPE_CHECKING, AnyStr, BinaryIO, Iterable, Iterator, Type

-from databricks.sdk.core import DatabricksError
+from databricks.sdk.core import ApiClient, DatabricksError

 from ..service import files
@@ -313,6 +314,13 @@ class DbfsExt(files.DbfsAPI):
     def open(self, path: str, *, read: bool = False, write: bool = False, overwrite: bool = False) -> _DbfsIO:
         return _DbfsIO(self, path, read=read, write=write, overwrite=overwrite)

+    def upload(self, path: str, src: BinaryIO, *, overwrite: bool = False):
+        with self.open(path, write=True, overwrite=overwrite) as dst:
+            shutil.copyfileobj(src, dst)
+
+    def download(self, path: str) -> BinaryIO:
+        return self.open(path, read=True)
+
     def list(self, path: str, *, recursive=False) -> Iterator[files.FileInfo]:
         """List directory contents or file details.
@@ -385,3 +393,18 @@ def move_(self, src: str, dst: str, *, recursive=False, overwrite=False):
         # do cross-fs moving
         self.copy(src, dst, recursive=recursive, overwrite=overwrite)
         source.delete(recursive=recursive)
+
+
+class FilesMixin:
+
+    def __init__(self, api_client: ApiClient):
+        self._api = api_client
+
+    def upload(self, path: str, src: BinaryIO):
+        self._api.do('PUT', f'/api/2.0/fs/files{path}', data=src)
+
+    def download(self, path: str) -> BinaryIO:
+        return self._api.do('GET', f'/api/2.0/fs/files{path}', raw=True)
Contributor: Is `path` always absolute?

Contributor (author): Yes.

+    def delete(self, path: str):
+        self._api.do('DELETE', f'/api/2.0/fs/files{path}')
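`FilesMixin` is a deliberately thin passthrough: `upload` hands the source IO to `do()` as the raw PUT body (requests streams file-like `data=` in chunks), `download` returns the raw stream from `raw=True`, and, unlike `DbfsExt.upload`, there is no `overwrite` flag here. A hedged sketch of round-tripping a local file through a Unity Catalog volume; the paths are illustrative:

# stream a local file up without loading it fully into memory
with open('/tmp/report.csv', 'rb') as src:
    w.files.upload('/Volumes/main/default/vol/report.csv', src)

# read back the first KiB of the raw HTTP stream
with w.files.download('/Volumes/main/default/vol/report.csv') as f:
    print(f.read(1024))

w.files.delete('/Volumes/main/default/vol/report.csv')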
97 changes: 97 additions & 0 deletions databricks/sdk/mixins/workspace.py

Contributor: Do we include some information in the User-Agent to indicate that these requests are coming from the Databricks SDK? This will help with tracking on our side.

Contributor (author): Yes, let me share the spec with you.

@@ -0,0 +1,97 @@
+from typing import BinaryIO, Iterator, Optional
+
+from ..core import DatabricksError
+from ..service.workspace import (ExportFormat, Language, ObjectInfo,
+                                 ObjectType, WorkspaceAPI)
+
+
+def _fqcn(x: any) -> str:
+    return f'{x.__module__}.{x.__name__}'
+
+
+class WorkspaceExt(WorkspaceAPI):
+
+    def list(self,
+             path: str,
+             *,
+             notebooks_modified_after: Optional[int] = None,
+             recursive: Optional[bool] = False,
+             **kwargs) -> Iterator[ObjectInfo]:
+        parent_list = super().list
+        queue = [path]
+        while queue:
+            path, queue = queue[0], queue[1:]
+            for object_info in parent_list(path, notebooks_modified_after=notebooks_modified_after):
+                if recursive and object_info.object_type == ObjectType.DIRECTORY:
+                    queue.append(object_info.path)
+                    continue
+                yield object_info
+
+    def upload(self,
+               path: str,
+               content: BinaryIO,
+               *,
+               format: Optional[ExportFormat] = None,
+               language: Optional[Language] = None,
+               overwrite: Optional[bool] = False) -> None:
+        """
+        Uploads a workspace object (for example, a notebook or file) or the contents of an entire
+        directory (`DBC` format).
+
+        Errors:
+          * `RESOURCE_ALREADY_EXISTS`: if `path` already exists and `overwrite=True` is not set.
+          * `INVALID_PARAMETER_VALUE`: if `format` and `content` values are not compatible.
+
+        :param path: target location of the file in the workspace.
+        :param content: file-like `io.BinaryIO` of the `path` contents.
+        :param format: By default, `ExportFormat.SOURCE`. If using `ExportFormat.AUTO`, the `path`
+                       is imported or exported as either a workspace file or a notebook, depending
+                       on an analysis of the `item`'s extension and the header content provided in
+                       the request. In addition, if the `path` is imported as a notebook, then
+                       the `item`'s extension is automatically removed.
+        :param language: Only required if using `ExportFormat.SOURCE`.
+        """
+        if format is not None and not isinstance(format, ExportFormat):
+            raise ValueError(
+                f'format is expected to be {_fqcn(ExportFormat)}, but got {_fqcn(format.__class__)}')
+        if (not format or format == ExportFormat.SOURCE) and not language:
+            suffixes = {
+                '.py': Language.PYTHON,
+                '.sql': Language.SQL,
+                '.scala': Language.SCALA,
+                '.R': Language.R
+            }
+            for sfx, lang in suffixes.items():
+                if path.endswith(sfx):
+                    language = lang
+                    break
+        if language is not None and not isinstance(language, Language):
+            raise ValueError(
+                f'language is expected to be {_fqcn(Language)}, but got {_fqcn(language.__class__)}')
+        data = {'path': path}
+        if format: data['format'] = format.value
+        if language: data['language'] = language.value
+        if overwrite: data['overwrite'] = 'true'
+        try:
+            return self._api.do('POST', '/api/2.0/workspace/import', files={'content': content}, data=data)
+        except DatabricksError as e:
+            if e.error_code == 'INVALID_PARAMETER_VALUE':
+                msg = f'Perhaps you forgot to specify the `format=ExportFormat.AUTO`. {e}'
+                raise DatabricksError(message=msg, error_code=e.error_code)
+            else:
+                raise e
+
+    def download(self, path: str, *, format: Optional[ExportFormat] = None) -> BinaryIO:
+        """
+        Downloads a notebook or file from the workspace.
+
+        :param path: location of the file or notebook in the workspace.
+        :param format: By default, `ExportFormat.SOURCE`. If using `ExportFormat.AUTO`, the `path`
+                       is imported or exported as either a workspace file or a notebook, depending
+                       on an analysis of the `item`'s extension and the header content provided in
+                       the request.
+        :return: file-like `io.BinaryIO` of the `path` contents.
+        """
+        query = {'path': path, 'direct_download': 'true'}
+        if format: query['format'] = format.value
+        return self._api.do('GET', '/api/2.0/workspace/export', query=query, raw=True)
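A short sketch of how the two new methods compose, mirroring the integration tests below; the user path is illustrative. Uploading a `.py` file without an explicit `language` lets the suffix table pick `Language.PYTHON`, and exporting it back in the default `SOURCE` format returns the notebook header:

import io

nb = '/Users/me@example.com/scratch-notebook.py'

# language inferred from the '.py' suffix; format defaults to ExportFormat.SOURCE
w.workspace.upload(nb, io.BytesIO(b'print(1)'))

with w.workspace.download(nb) as f:
    assert f.read() == b'# Databricks notebook source\nprint(1)'

w.workspace.delete(nb)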
1 change: 1 addition & 0 deletions databricks/sdk/service/sql.py

(Generated file; diff not rendered.)

23 changes: 23 additions & 0 deletions tests/integration/test_dbfs.py → tests/integration/test_files.py
@@ -1,3 +1,4 @@
+import io
 import pathlib
 from typing import List

@@ -170,3 +171,25 @@ def test_move_from_dbfs_to_local(w, random, junk, tmp_path):
         assert f.read() == payload_02
     with (tmp_path / root.name / 'a/b/03').open('rb') as f:
         assert f.read() == payload_03
+
+
+def test_dbfs_upload_download(w, random, junk, tmp_path):
+    root = pathlib.Path(f'/tmp/{random()}')
+
+    f = io.BytesIO(b"some text data")
+    w.dbfs.upload(f'{root}/01', f)
+
+    with w.dbfs.download(f'{root}/01') as f:
+        assert f.read() == b"some text data"
+
+
+def test_files_api_upload_download(w, random):
+    pytest.skip()
+    f = io.BytesIO(b"some text data")
+    target_file = f'/Volumes/bogdanghita/default/v3_shared/sdk-testing/{random(10)}.txt'
+    w.files.upload(target_file, f)
+
+    with w.files.download(target_file) as f:
+        assert f.read() == b"some text data"
+
+    w.files.delete(target_file)
54 changes: 54 additions & 0 deletions tests/integration/test_workspace.py
@@ -0,0 +1,54 @@
+import io
+
+from databricks.sdk.service.workspace import ExportFormat, Language
+
+
+def test_workspace_recursive_list(w, random):
+    names = []
+    for i in w.workspace.list(f'/Users/{w.current_user.me().user_name}', recursive=True):
+        names.append(i.path)
+    assert len(names) > 0
+
+
+def test_workspace_upload_download_notebooks(w, random):
+    notebook = f'/Users/{w.current_user.me().user_name}/notebook-{random(12)}.py'
+
+    w.workspace.upload(notebook, io.BytesIO(b'print(1)'))
+    with w.workspace.download(notebook) as f:
+        content = f.read()
+        assert content == b'# Databricks notebook source\nprint(1)'
+
+    w.workspace.delete(notebook)
+
+
+def test_workspace_upload_download_files(w, random):
+    py_file = f'/Users/{w.current_user.me().user_name}/file-{random(12)}.py'
+
+    w.workspace.upload(py_file, io.BytesIO(b'print(1)'), format=ExportFormat.AUTO)
+    with w.workspace.download(py_file) as f:
+        content = f.read()
+        assert content == b'print(1)'
+
+    w.workspace.delete(py_file)
+
+
+def test_workspace_upload_download_txt_files(w, random):
+    txt_file = f'/Users/{w.current_user.me().user_name}/txt-{random(12)}.txt'
+
+    w.workspace.upload(txt_file, io.BytesIO(b'print(1)'), format=ExportFormat.AUTO)
+    with w.workspace.download(txt_file) as f:
+        content = f.read()
+        assert content == b'print(1)'
+
+    w.workspace.delete(txt_file)
+
+
+def test_workspace_upload_download_notebooks_no_extension(w, random):
+    nb = f'/Users/{w.current_user.me().user_name}/notebook-{random(12)}'
+
+    w.workspace.upload(nb, io.BytesIO(b'print(1)'), format=ExportFormat.SOURCE, language=Language.PYTHON)
+    with w.workspace.download(nb) as f:
+        content = f.read()
+        assert content == b'# Databricks notebook source\nprint(1)'
+
+    w.workspace.delete(nb)
8 changes: 4 additions & 4 deletions tests/test_dbutils.py
@@ -10,7 +10,7 @@ def dbutils(config):


 def test_fs_cp(dbutils, mocker):
-    inner = mocker.patch('databricks.sdk.mixins.dbfs.DbfsExt.copy')
+    inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.copy')

     dbutils.fs.cp('a', 'b', recurse=True)
@@ -30,7 +30,7 @@ def test_fs_head(dbutils, mocker):

 def test_fs_ls(dbutils, mocker):
     from databricks.sdk.service.files import FileInfo
-    inner = mocker.patch('databricks.sdk.mixins.dbfs.DbfsExt.list',
+    inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.list',
                          return_value=[
                              FileInfo(path='b', file_size=10, modification_time=20),
                              FileInfo(path='c', file_size=30, modification_time=40),
@@ -53,7 +53,7 @@ def test_fs_mkdirs(dbutils, mocker):


 def test_fs_mv(dbutils, mocker):
-    inner = mocker.patch('databricks.sdk.mixins.dbfs.DbfsExt.move_')
+    inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.move_')

     dbutils.fs.mv('a', 'b')

Expand All @@ -75,7 +75,7 @@ def write(self, contents):
self._written = contents

mock_open = _MockOpen()
inner = mocker.patch('databricks.sdk.mixins.dbfs.DbfsExt.open', return_value=mock_open)
inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.open', return_value=mock_open)

dbutils.fs.put('a', 'b')
