Skip to content

Commit

Permalink
Fix: notebook_login() does not update UI on Databricks (#1414)
Browse files Browse the repository at this point in the history
* wip

* catch errors to display

* change output

* display error

* change error output

* refactor

* support multiple string objects

* fix strings and add typing

* black

* style

* Use capture_output in notebook_login

* Fix messages

* fix tests

---------

Co-authored-by: Lucain Pouget <lucainp@gmail.com>
  • Loading branch information
fwetdb and Wauplin authored Apr 4, 2023
1 parent d751269 commit 7a3e1e8
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 37 deletions.
19 changes: 14 additions & 5 deletions src/huggingface_hub/_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .hf_api import HfApi
from .utils import (
HfFolder,
capture_output,
is_google_colab,
is_notebook,
list_credential_helpers,
Expand Down Expand Up @@ -180,7 +181,7 @@ def notebook_login() -> None:
"""
try:
import ipywidgets.widgets as widgets # type: ignore
from IPython.display import clear_output, display # type: ignore
from IPython.display import display # type: ignore
except ImportError:
raise ImportError(
"The `notebook_login` function can only be used in a notebook (Jupyter or"
Expand Down Expand Up @@ -211,8 +212,16 @@ def login_token_event(t):
add_to_git_credential = git_checkbox_widget.value
# Erase token and clear value to make sure it's not saved in the notebook.
token_widget.value = ""
clear_output()
_login(token, add_to_git_credential=add_to_git_credential)
# Hide inputs
login_token_widget.children = [widgets.Label("Connecting...")]
try:
with capture_output() as captured:
_login(token, add_to_git_credential=add_to_git_credential)
message = captured.getvalue()
except Exception as error:
message = str(error)
# Print result (success message or error)
login_token_widget.children = [widgets.Label(line) for line in message.split("\n") if line.strip()]

token_finish_button.on_click(login_token_event)

Expand All @@ -235,13 +244,13 @@ def _login(token: str, add_to_git_credential: bool) -> None:
set_git_credential(token)
print(
"Your token has been saved in your configured git credential helpers"
f" ({','.join(list_credential_helpers())})."
+ f" ({','.join(list_credential_helpers())})."
)
else:
print("Token has not been saved to git credential helper.")

HfFolder.save_token(token)
print("Your token has been saved to", HfFolder.path_token)
print(f"Your token has been saved to {HfFolder.path_token}")
print("Login successful")


Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
is_tf_available,
is_torch_available,
)
from ._subprocess import run_interactive_subprocess, run_subprocess
from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess
from ._validators import (
HFValidationError,
smoothly_deprecate_use_auth_token,
Expand Down
22 changes: 22 additions & 0 deletions src/huggingface_hub/utils/_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
"""Contains utilities to easily handle subprocesses in `huggingface_hub`."""
import os
import subprocess
import sys
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from typing import IO, Generator, List, Optional, Tuple, Union

Expand All @@ -26,6 +28,26 @@
logger = get_logger(__name__)


@contextmanager
def capture_output() -> Generator[StringIO, None, None]:
"""Capture output that is printed to terminal.
Taken from https://stackoverflow.com/a/34738440
Example:
```py
>>> with capture_output() as output:
... print("hello world")
>>> assert output.getvalue() == "hello world\n"
```
"""
output = StringIO()
previous_output = sys.stdout
sys.stdout = output
yield output
sys.stdout = previous_output


def run_subprocess(
command: Union[str, List[str]],
folder: Optional[Union[str, Path]] = None,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_command_delete_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
_manual_review_no_tui,
_read_manual_review_tmp_file,
)
from huggingface_hub.utils import SoftTemporaryDirectory
from huggingface_hub.utils import SoftTemporaryDirectory, capture_output

from .testing_utils import capture_output, handle_injection
from .testing_utils import handle_injection


class TestDeleteCacheHelpers(unittest.TestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/test_utils_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from huggingface_hub._snapshot_download import snapshot_download
from huggingface_hub.commands.scan_cache import ScanCacheCommand
from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, scan_cache_dir
from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, capture_output, scan_cache_dir
from huggingface_hub.utils._cache_manager import (
CacheNotFound,
_format_size,
Expand All @@ -19,7 +19,6 @@

from .testing_constants import TOKEN
from .testing_utils import (
capture_output,
rmtree_with_retry,
with_production_testing,
xfail_on_windows,
Expand Down
28 changes: 1 addition & 27 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
import os
import shutil
import stat
import sys
import time
import unittest
import uuid
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from io import StringIO
from pathlib import Path
from typing import Callable, Generator, Optional, Type, TypeVar, Union
from typing import Callable, Optional, Type, TypeVar, Union
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -321,30 +319,6 @@ def _inner_decorator(test_function: Callable) -> Callable:
return _inner_decorator


@contextmanager
def capture_output() -> Generator[StringIO, None, None]:
"""Capture output that is printed to console.
Especially useful to test CLI commands.
Taken from https://stackoverflow.com/a/34738440
Example:
```py
class TestHelloWorld(unittest.TestCase):
def test_hello_world(self):
with capture_output() as output:
print("hello world")
self.assertEqual(output.getvalue(), "hello world\n")
```
"""
output = StringIO()
previous_output = sys.stdout
sys.stdout = output
yield output
sys.stdout = previous_output


T = TypeVar("T")


Expand Down

0 comments on commit 7a3e1e8

Please sign in to comment.