Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support tag management (create, list, delete) in CLI #2172

Merged
merged 20 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/huggingface_hub/commands/huggingface_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from huggingface_hub.commands.env import EnvironmentCommand
from huggingface_hub.commands.lfs import LfsCommands
from huggingface_hub.commands.scan_cache import ScanCacheCommand
from huggingface_hub.commands.tag import TagCommands
from huggingface_hub.commands.upload import UploadCommand
from huggingface_hub.commands.user import UserCommands

Expand All @@ -36,6 +37,7 @@ def main():
LfsCommands.register_subcommand(commands_parser)
ScanCacheCommand.register_subcommand(commands_parser)
DeleteCacheCommand.register_subcommand(commands_parser)
TagCommands.register_subcommand(commands_parser)

# Let's go
args = parser.parse_args()
Expand Down
194 changes: 194 additions & 0 deletions src/huggingface_hub/commands/tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains commands to perform tag management with the CLI.

Usage Examples:
- Create a tag:
$ huggingface-cli tag user/my-model 1.0 --message "First release"
$ huggingface-cli tag user/my-model 1.0 -m "First release" --revision develop
$ huggingface-cli tag user/my-dataset 1.0 -m "First release" --repo-type dataset
- List all tags:
$ huggingface-cli tag -l user/my-model
$ huggingface-cli tag --list user/my-dataset --repo-type dataset
- Delete a tag:
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
$ huggingface-cli tag -d user/my-model 1.0
$ huggingface-cli tag --delete user/my-dataset 1.0 --repo-type dataset
"""

import subprocess
from argparse import Namespace, _SubParsersAction

from requests.exceptions import HTTPError

from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.constants import (
REPO_TYPES,
)
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils import get_token

from ..utils import HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
from ._cli_utils import ANSI


class TagCommands(BaseHuggingfaceCLICommand):
@staticmethod
def register_subcommand(parser: _SubParsersAction):
tag_parser = parser.add_parser("tag", help="(create, list, delete) tags for a model in the hub")

tag_parser.add_argument(
"repo_id", type=str, help="The repository (model, dataset, or space) for the operation."
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
)
tag_parser.add_argument("tag", nargs="?", type=str, help="The name of the tag for creation or deletion.")
tag_parser.add_argument("-m", "--message", type=str, help="The description of the tag to create.")
tag_parser.add_argument("--revision", type=str, help="The git revision to tag.")
tag_parser.add_argument("--token", type=str, help="Authentication token.")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
tag_parser.add_argument(
"--repo-type",
choices=["model", "dataset", "space"],
default="model",
help="Set the type of repository (model, dataset, or space).",
)
tag_parser.add_argument("-y", "--yes", action="store_true", help="Answer Yes to prompts automatically.")
tag_parser.add_argument("--force", action="store_true", help="Force tag creation or deletion.")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved

tag_parser.add_argument("-l", "--list", action="store_true", help="List tags for a repository.")
tag_parser.add_argument("-d", "--delete", action="store_true", help="Delete a tag for a repository.")

tag_parser.set_defaults(func=lambda args: handle_commands(args))


def handle_commands(args: Namespace):
if args.list:
return TagListCommand(args)
elif args.delete:
return TagDeleteCommand(args)
else:
return TagCreateCommand(args)


class TagCommand:
def __init__(self, args: Namespace):
self.args = args
self._api = HfApi()
self.token = self.args.token if self.args.token is not None else get_token()
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
self.user = self._api.whoami(self.token)["name"]
self.repo_id = self.args.repo_id
self.check_git_installed()
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
self.check_fields()

def check_git_installed(self):
try:
stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
print(ANSI.gray(stdout.strip()))
except FileNotFoundError:
print("Looks like you do not have git installed, please install.")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved

def check_fields(self):
self.token = self.args.token if self.args.token is not None else get_token()
self.user = self._api.whoami(self.token)["name"]
if self.args.repo_type not in REPO_TYPES:
print("Invalid repo --repo-type")
exit(1)
if self.token is None:
print("Not logged in")
exit(1)
Wauplin marked this conversation as resolved.
Show resolved Hide resolved


class TagCreateCommand(TagCommand):
def run(self):
if self.args.message is None:
print("Tag message cannot be empty. Please provide a message with `--tag-message`.")
exit(1)
Wauplin marked this conversation as resolved.
Show resolved Hide resolved

print(f"You are about to create tag {ANSI.bold(self.args.tag)} on {ANSI.bold(self.repo_id)}")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved

if not self.args.yes:
choice = input("Proceed? [Y/n] ").lower()
if not (choice == "" or choice == "y" or choice == "yes"):
print("Abort")
exit()
try:
self._api.create_tag(
repo_id=self.repo_id,
tag=self.args.tag,
tag_message=self.args.message,
revision=self.args.revision,
token=self.token,
exist_ok=self.args.force,
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
repo_type=self.args.repo_type,
)
except RepositoryNotFoundError:
print(f"Repository {ANSI.bold(self.repo_id)} not found.")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
exit(1)
except RevisionNotFoundError:
print(f"Revision {ANSI.bold(self.args.revision)} not found.")
exit(1)
except HfHubHTTPError as e:
if e.response.status_code == 409:
print(f"Tag {ANSI.bold(self.args.tag)} already exists on {ANSI.bold(self.repo_id)}")
print("Use `--force` to overwrite the existing tag.")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
exit(1)
raise e

print(f"Tag {ANSI.bold(self.args.tag)} created on {ANSI.bold(self.repo_id)}")
print("")


class TagListCommand(TagCommand):
def run(self):
try:
refs = self._api.list_repo_refs(
repo_id=self.repo_id,
repo_type=self.args.repo_type,
)
except RepositoryNotFoundError:
print(f"Repository {ANSI.bold(self.repo_id)} not found.")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
exit(1)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
if len(refs.tags) == 0:
print(" No tags found")
exit(0)
print("\nYour tags:")
for tag in refs.tags:
print(f" {ANSI.bold(tag.name)}")
print("")


class TagDeleteCommand(TagCommand):
def run(self):
print(f"You are about to delete tag {ANSI.bold(self.args.tag)} on {ANSI.bold(self.repo_id)}")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved

if not self.args.yes:
choice = input("Proceed? [Y/n] ").lower()
if not (choice == "" or choice == "y" or choice == "yes"):
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
print("Abort")
exit()
try:
self._api.delete_tag(
repo_id=self.repo_id, tag=self.args.tag, token=self.token, repo_type=self.args.repo_type
)
except RepositoryNotFoundError:
print(f"Repository {ANSI.bold(self.repo_id)} not found.")
exit(1)
except RevisionNotFoundError:
print(f"Tag {ANSI.bold(self.args.tag)} not found on {ANSI.bold(self.repo_id)}")
exit(1)
print(f"Tag {ANSI.bold(self.args.tag)} deleted on {ANSI.bold(self.repo_id)}")
print("")
62 changes: 62 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from huggingface_hub.commands.delete_cache import DeleteCacheCommand
from huggingface_hub.commands.download import DownloadCommand
from huggingface_hub.commands.scan_cache import ScanCacheCommand
from huggingface_hub.commands.tag import TagCommands
from huggingface_hub.commands.upload import UploadCommand
from huggingface_hub.utils import RevisionNotFoundError, SoftTemporaryDirectory, capture_output

Expand Down Expand Up @@ -560,6 +561,67 @@ def test_download_with_ignored_patterns(self, mock: Mock) -> None:
DownloadCommand(args).run()


class TestTagCommands(unittest.TestCase):
def setUp(self) -> None:
"""
Set up CLI as in `src/huggingface_hub/commands/huggingface_cli.py`.
"""
self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli <command> [<args>]")
commands_parser = self.parser.add_subparsers()
TagCommands.register_subcommand(commands_parser)

def test_tag_create_basic(self) -> None:
args = self.parser.parse_args(["tag", DUMMY_MODEL_ID, "1.0", "-m", "My tag message"])
self.assertEqual(args.repo_id, DUMMY_MODEL_ID)
self.assertEqual(args.tag, "1.0")
self.assertIsNotNone(args.message)
self.assertIsNone(args.revision)
self.assertIsNone(args.token)
self.assertEqual(args.repo_type, "model")
self.assertFalse(args.yes)

def test_tag_create_with_all_options(self) -> None:
args = self.parser.parse_args(
[
"tag",
DUMMY_MODEL_ID,
"1.0",
"--message",
"My tag message",
"--revision",
"v1.0.0",
"--token",
"my-token",
"--repo-type",
"dataset",
"--yes",
"--force",
]
)
self.assertEqual(args.repo_id, DUMMY_MODEL_ID)
self.assertEqual(args.tag, "1.0")
self.assertEqual(args.message, "My tag message")
self.assertEqual(args.revision, "v1.0.0")
self.assertEqual(args.token, "my-token")
self.assertEqual(args.repo_type, "dataset")
self.assertTrue(args.yes)
self.assertTrue(args.force)

def test_tag_list_basic(self) -> None:
args = self.parser.parse_args(["tag", "--list", DUMMY_MODEL_ID])
self.assertEqual(args.repo_id, DUMMY_MODEL_ID)
self.assertIsNone(args.token)
self.assertEqual(args.repo_type, "model")

def test_tag_delete_basic(self) -> None:
args = self.parser.parse_args(["tag", "--delete", DUMMY_MODEL_ID, "1.0"])
self.assertEqual(args.repo_id, DUMMY_MODEL_ID)
self.assertEqual(args.tag, "1.0")
self.assertIsNone(args.token)
self.assertEqual(args.repo_type, "model")
self.assertFalse(args.yes)


@contextmanager
def tmp_current_directory() -> Generator[str, None, None]:
"""Change current directory to a tmp dir and revert back when exiting."""
Expand Down
Loading