-
Notifications
You must be signed in to change notification settings - Fork 131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Consistently use the upload(path, IO)
and download(path) -> IO
across file-related operations
#148
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,13 @@ | |
|
||
import base64 | ||
import pathlib | ||
import shutil | ||
import sys | ||
from abc import ABC, abstractmethod | ||
from types import TracebackType | ||
from typing import TYPE_CHECKING, AnyStr, BinaryIO, Iterable, Iterator, Type | ||
|
||
from databricks.sdk.core import DatabricksError | ||
from databricks.sdk.core import ApiClient, DatabricksError | ||
|
||
from ..service import files | ||
|
||
|
@@ -313,6 +314,13 @@ class DbfsExt(files.DbfsAPI): | |
def open(self, path: str, *, read: bool = False, write: bool = False, overwrite: bool = False) -> _DbfsIO:
    """Create a `_DbfsIO` handle for `path`, forwarding the mode flags unchanged.

    :param path: location of the file to open.
    :param read: passed through to `_DbfsIO` as `read`.
    :param write: passed through to `_DbfsIO` as `write`.
    :param overwrite: passed through to `_DbfsIO` as `overwrite`.
    :return: the newly constructed `_DbfsIO` object bound to this API instance.
    """
    handle = _DbfsIO(self, path, read=read, write=write, overwrite=overwrite)
    return handle
|
||
def upload(self, path: str, src: BinaryIO, *, overwrite: bool = False):
    """Copy everything readable from `src` into the file at `path`.

    The destination is opened through `self.open` in write mode and used as a
    context manager, so it is released when the copy finishes or fails.

    :param path: destination location.
    :param src: binary file-like object to read from.
    :param overwrite: forwarded to `self.open` as `overwrite`.
    """
    remote = self.open(path, write=True, overwrite=overwrite)
    with remote as sink:
        shutil.copyfileobj(src, sink)
|
||
def download(self, path: str) -> BinaryIO:
    """Open `path` for reading and hand the resulting stream to the caller.

    :param path: location of the file to read.
    :return: whatever `self.open(path, read=True)` returns; the caller is
        responsible for closing it.
    """
    stream = self.open(path, read=True)
    return stream
|
||
def list(self, path: str, *, recursive=False) -> Iterator[files.FileInfo]: | ||
"""List directory contents or file details. | ||
|
||
|
@@ -385,3 +393,18 @@ def move_(self, src: str, dst: str, *, recursive=False, overwrite=False): | |
# do cross-fs moving | ||
self.copy(src, dst, recursive=recursive, overwrite=overwrite) | ||
source.delete(recursive=recursive) | ||
|
||
|
||
class FilesMixin:
    """Minimal client for the `/api/2.0/fs/files` endpoints.

    Every method takes an absolute `path` (reviewers asked "Is path always
    absolute?" and the author confirmed "yes"), which is appended to the
    endpoint prefix after percent-encoding.
    """

    def __init__(self, api_client: 'ApiClient'):
        """Remember the API client used to issue the HTTP requests."""
        self._api = api_client

    @staticmethod
    def _file_url(path: str) -> str:
        """Build the request URL for `path`, percent-encoding unsafe characters.

        Slashes are kept as-is so the directory structure survives; characters
        such as `#`, `?`, `%` or spaces would otherwise be read as URL syntax
        and make the request target a different resource.
        """
        # Function-scope import keeps this class self-contained.
        from urllib.parse import quote
        return f'/api/2.0/fs/files{quote(path)}'

    def upload(self, path: str, src: BinaryIO):
        """Send the contents of `src` to `path` with a `PUT` request.

        :param path: absolute destination path.
        :param src: binary file-like object passed as the request body.
        """
        self._api.do('PUT', self._file_url(path), data=src)

    def download(self, path: str) -> BinaryIO:
        """Fetch `path` with a `GET` request issued with `raw=True`.

        :param path: absolute path of the file to read.
        :return: the value returned by the API client for the raw request.
        """
        return self._api.do('GET', self._file_url(path), raw=True)

    def delete(self, path: str):
        """Remove `path` with a `DELETE` request.

        :param path: absolute path of the file to delete.
        """
        self._api.do('DELETE', self._file_url(path))
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we include some information in the User Agent to indicate that these requests are coming from the Databricks SDK? This will help with tracking on our side. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, let me share the spec with you. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import BinaryIO, Iterator, Optional | ||
|
||
from ..core import DatabricksError | ||
from ..service.workspace import (ExportFormat, Language, ObjectInfo, | ||
ObjectType, WorkspaceAPI) | ||
|
||
|
||
def _fqcn(x: any) -> str: | ||
return f'{x.__module__}.{x.__name__}' | ||
|
||
|
||
class WorkspaceExt(WorkspaceAPI):
    """`WorkspaceAPI` extended with recursive listing and stream-based upload/download."""

    def list(self,
             path: str,
             *,
             notebooks_modified_after: Optional[int] = None,
             recursive: Optional[bool] = False,
             **kwargs) -> Iterator[ObjectInfo]:
        """List workspace objects under `path`, optionally descending into directories.

        Directories are visited breadth-first. When `recursive` is true,
        directories themselves are not yielded, only their contents; when it is
        false, every entry returned by the parent `list` is yielded as-is.

        :param path: workspace path to start listing from.
        :param notebooks_modified_after: forwarded unchanged to the parent `list`.
        :param recursive: whether to descend into sub-directories.
        :param kwargs: accepted but not used by this implementation.
        :return: iterator over `ObjectInfo` entries.
        """
        # Function-scope import keeps this class self-contained.
        from collections import deque

        parent_list = super().list
        # A deque gives O(1) pops from the left; slicing a list copied it on every step.
        pending = deque([path])
        while pending:
            current = pending.popleft()
            for object_info in parent_list(current, notebooks_modified_after=notebooks_modified_after):
                if recursive and object_info.object_type == ObjectType.DIRECTORY:
                    pending.append(object_info.path)
                    continue
                yield object_info

    def upload(self,
               path: str,
               content: BinaryIO,
               *,
               format: Optional[ExportFormat] = None,
               language: Optional[Language] = None,
               overwrite: Optional[bool] = False) -> None:
        """
        Uploads a workspace object (for example, a notebook or file) or the contents of an entire
        directory (`DBC` format).

        Errors:
         * `RESOURCE_ALREADY_EXISTS`: if `path` already exists and `overwrite=True` was not given.
         * `INVALID_PARAMETER_VALUE`: if `format` and `content` values are not compatible.

        :param path:     target location of the file on workspace.
        :param content:  file-like `io.BinaryIO` of the `path` contents.
        :param format:   By default, `ExportFormat.SOURCE`. If using `ExportFormat.AUTO` the `path`
                         is imported or exported as either a workspace file or a notebook, depending
                         on an analysis of the `path`'s extension and the header content provided in
                         the request. In addition, if the `path` is imported as a notebook, then
                         its extension is automatically removed.
        :param language: Only required if using `ExportFormat.SOURCE`. When omitted, it is inferred
                         from a `.py`, `.sql`, `.scala` or `.R` suffix of `path`, if present.
        :param overwrite: when true, `overwrite=true` is sent with the request.
        :raises ValueError: if `format` or `language` is not an instance of the expected enum.
        :raises DatabricksError: if the import request fails.
        """
        if format is not None and not isinstance(format, ExportFormat):
            raise ValueError(
                f'format is expected to be {_fqcn(ExportFormat)}, but got {_fqcn(format.__class__)}')
        if (not format or format == ExportFormat.SOURCE) and not language:
            # No explicit language for a source import: guess it from the file suffix.
            suffixes = {
                '.py': Language.PYTHON,
                '.sql': Language.SQL,
                '.scala': Language.SCALA,
                '.R': Language.R
            }
            for sfx, lang in suffixes.items():
                if path.endswith(sfx):
                    language = lang
                    break
        if language is not None and not isinstance(language, Language):
            raise ValueError(
                f'language is expected to be {_fqcn(Language)}, but got {_fqcn(language.__class__)}')
        data = {'path': path}
        if format: data['format'] = format.value
        if language: data['language'] = language.value
        if overwrite: data['overwrite'] = 'true'
        try:
            return self._api.do('POST', '/api/2.0/workspace/import', files={'content': content}, data=data)
        except DatabricksError as e:
            if e.error_code == 'INVALID_PARAMETER_VALUE':
                msg = f'Perhaps you forgot to specify the `format=ExportFormat.AUTO`. {e}'
                # Chain the original error so its traceback and details are not lost.
                raise DatabricksError(message=msg, error_code=e.error_code) from e
            # Bare raise re-raises the active exception with its traceback untouched.
            raise

    def download(self, path: str, *, format: Optional[ExportFormat] = None) -> BinaryIO:
        """
        Downloads notebook or file from the workspace

        :param path:     location of the file or notebook on workspace.
        :param format:   By default, `ExportFormat.SOURCE`. If using `ExportFormat.AUTO` the `path`
                         is imported or exported as either a workspace file or a notebook, depending
                         on an analysis of the `path`'s extension and the header content provided in
                         the request.
        :return:         file-like `io.BinaryIO` of the `path` contents.
        """
        query = {'path': path, 'direct_download': 'true'}
        if format: query['format'] = format.value
        return self._api.do('GET', '/api/2.0/workspace/export', query=query, raw=True)
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import io | ||
|
||
from databricks.sdk.service.workspace import ExportFormat, Language | ||
|
||
|
||
def test_workspace_recursive_list(w, random):
    """Recursively list the current user's home folder and expect at least one entry."""
    home = f'/Users/{w.current_user.me().user_name}'
    names = [entry.path for entry in w.workspace.list(home, recursive=True)]
    assert len(names) > 0
|
||
|
||
def test_workspace_upload_download_notebooks(w, random):
    """Upload a `.py` path without an explicit format and expect a notebook back.

    The downloaded content is expected to carry the
    `# Databricks notebook source` header in front of the uploaded bytes.
    """
    notebook = f'/Users/{w.current_user.me().user_name}/notebook-{random(12)}.py'

    w.workspace.upload(notebook, io.BytesIO(b'print(1)'))
    try:
        with w.workspace.download(notebook) as f:
            content = f.read()
        assert content == b'# Databricks notebook source\nprint(1)'
    finally:
        # Delete even when the download or assertion fails, so a failed run
        # does not leave stray objects in the user's folder.
        w.workspace.delete(notebook)
|
||
|
||
def test_workspace_upload_download_files(w, random):
    """Upload a `.py` path with `ExportFormat.AUTO` and expect the bytes back unchanged."""
    py_file = f'/Users/{w.current_user.me().user_name}/file-{random(12)}.py'

    w.workspace.upload(py_file, io.BytesIO(b'print(1)'), format=ExportFormat.AUTO)
    try:
        with w.workspace.download(py_file) as f:
            content = f.read()
        assert content == b'print(1)'
    finally:
        # Delete even when the download or assertion fails, so a failed run
        # does not leave stray objects in the user's folder.
        w.workspace.delete(py_file)
|
||
|
||
def test_workspace_upload_download_txt_files(w, random):
    """Upload a `.txt` path with `ExportFormat.AUTO` and expect the bytes back unchanged."""
    txt_file = f'/Users/{w.current_user.me().user_name}/txt-{random(12)}.txt'

    w.workspace.upload(txt_file, io.BytesIO(b'print(1)'), format=ExportFormat.AUTO)
    try:
        with w.workspace.download(txt_file) as f:
            content = f.read()
        assert content == b'print(1)'
    finally:
        # Delete even when the download or assertion fails, so a failed run
        # does not leave stray objects in the user's folder.
        w.workspace.delete(txt_file)
|
||
|
||
def test_workspace_upload_download_notebooks_no_extension(w, random):
    """Upload to a suffix-less path with explicit SOURCE format and PYTHON language.

    The downloaded content is expected to carry the
    `# Databricks notebook source` header in front of the uploaded bytes.
    """
    nb = f'/Users/{w.current_user.me().user_name}/notebook-{random(12)}'

    w.workspace.upload(nb, io.BytesIO(b'print(1)'), format=ExportFormat.SOURCE, language=Language.PYTHON)
    try:
        with w.workspace.download(nb) as f:
            content = f.read()
        assert content == b'# Databricks notebook source\nprint(1)'
    finally:
        # Delete even when the download or assertion fails, so a failed run
        # does not leave stray objects in the user's folder.
        w.workspace.delete(nb)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this need to be added here explicitly because there is no Files tag in the OpenAPI spec?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, temporarily