Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: notebook_login() does not update UI on Databricks #1414

Merged
merged 15 commits into from
Apr 4, 2023
19 changes: 14 additions & 5 deletions src/huggingface_hub/_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .hf_api import HfApi
from .utils import (
HfFolder,
capture_output,
is_google_colab,
is_notebook,
list_credential_helpers,
Expand Down Expand Up @@ -180,7 +181,7 @@ def notebook_login() -> None:
"""
try:
import ipywidgets.widgets as widgets # type: ignore
from IPython.display import clear_output, display # type: ignore
from IPython.display import display # type: ignore
except ImportError:
raise ImportError(
"The `notebook_login` function can only be used in a notebook (Jupyter or"
Expand Down Expand Up @@ -211,8 +212,16 @@ def login_token_event(t):
add_to_git_credential = git_checkbox_widget.value
# Erase token and clear value to make sure it's not saved in the notebook.
token_widget.value = ""
clear_output()
_login(token, add_to_git_credential=add_to_git_credential)
# Hide inputs
login_token_widget.children = [widgets.Label("Connecting...")]
try:
with capture_output() as captured:
_login(token, add_to_git_credential=add_to_git_credential)
message = captured.getvalue()
except Exception as error:
message = str(error)
# Print result (success message or error)
login_token_widget.children = [widgets.Label(line) for line in message.split("\n") if line.strip()]

token_finish_button.on_click(login_token_event)

Expand All @@ -235,13 +244,13 @@ def _login(token: str, add_to_git_credential: bool) -> None:
set_git_credential(token)
print(
"Your token has been saved in your configured git credential helpers"
f" ({','.join(list_credential_helpers())})."
+ f" ({','.join(list_credential_helpers())})."
)
else:
print("Token has not been saved to git credential helper.")

HfFolder.save_token(token)
print("Your token has been saved to", HfFolder.path_token)
print(f"Your token has been saved to {HfFolder.path_token}")
print("Login successful")


Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
is_tf_available,
is_torch_available,
)
from ._subprocess import run_interactive_subprocess, run_subprocess
from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess
from ._validators import (
HFValidationError,
smoothly_deprecate_use_auth_token,
Expand Down
22 changes: 22 additions & 0 deletions src/huggingface_hub/utils/_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
"""Contains utilities to easily handle subprocesses in `huggingface_hub`."""
import os
import subprocess
import sys
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from typing import IO, Generator, List, Optional, Tuple, Union

Expand All @@ -26,6 +28,26 @@
logger = get_logger(__name__)


@contextmanager
def capture_output() -> Generator[StringIO, None, None]:
"""Capture output that is printed to terminal.

Taken from https://stackoverflow.com/a/34738440

Example:
```py
>>> with capture_output() as output:
... print("hello world")
>>> assert output.getvalue() == "hello world\n"
```
"""
output = StringIO()
previous_output = sys.stdout
sys.stdout = output
yield output
sys.stdout = previous_output


def run_subprocess(
command: Union[str, List[str]],
folder: Optional[Union[str, Path]] = None,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_command_delete_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
_manual_review_no_tui,
_read_manual_review_tmp_file,
)
from huggingface_hub.utils import SoftTemporaryDirectory
from huggingface_hub.utils import SoftTemporaryDirectory, capture_output

from .testing_utils import capture_output, handle_injection
from .testing_utils import handle_injection


class TestDeleteCacheHelpers(unittest.TestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/test_utils_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from huggingface_hub._snapshot_download import snapshot_download
from huggingface_hub.commands.scan_cache import ScanCacheCommand
from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, scan_cache_dir
from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, capture_output, scan_cache_dir
from huggingface_hub.utils._cache_manager import (
CacheNotFound,
_format_size,
Expand All @@ -19,7 +19,6 @@

from .testing_constants import TOKEN
from .testing_utils import (
capture_output,
rmtree_with_retry,
with_production_testing,
xfail_on_windows,
Expand Down
28 changes: 1 addition & 27 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
import os
import shutil
import stat
import sys
import time
import unittest
import uuid
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from io import StringIO
from pathlib import Path
from typing import Callable, Generator, Optional, Type, TypeVar, Union
from typing import Callable, Optional, Type, TypeVar, Union
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -321,30 +319,6 @@ def _inner_decorator(test_function: Callable) -> Callable:
return _inner_decorator


@contextmanager
def capture_output() -> Generator[StringIO, None, None]:
"""Capture output that is printed to console.

Especially useful to test CLI commands.

Taken from https://stackoverflow.com/a/34738440

Example:
```py
class TestHelloWorld(unittest.TestCase):
def test_hello_world(self):
with capture_output() as output:
print("hello world")
self.assertEqual(output.getvalue(), "hello world\n")
```
"""
output = StringIO()
previous_output = sys.stdout
sys.stdout = output
yield output
sys.stdout = previous_output


T = TypeVar("T")


Expand Down