Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add studio auth to datachain #514

Merged
merged 9 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ dependencies = [
"psutil",
"huggingface_hub",
"iterative-telemetry>=0.0.9",
"platformdirs"
"platformdirs",
"dvc-studio-client>=0.21,<1"
]

[project.optional-dependencies]
Expand Down
91 changes: 91 additions & 0 deletions src/datachain/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from datachain import Session, utils
from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
from datachain.lib.dc import DataChain
from datachain.studio import process_studio_cli_args
from datachain.telemetry import telemetry

if TYPE_CHECKING:
Expand Down Expand Up @@ -97,6 +98,92 @@ def add_show_args(parser: ArgumentParser) -> None:
)


def add_studio_parser(subparsers, parent_parser) -> None:
studio_help = "Commands to authenticate Datachain with Iterative Studio"
studio_description = (
"Authenticate Datachain with Studio and set the token. "
"Once this token has been properly configured,\n"
"Datachain will utilize it for seamlessly sharing datasets\n"
"and using Studio features from CLI"
)

studio_parser = subparsers.add_parser(
"studio",
parents=[parent_parser],
description=studio_description,
help=studio_help,
)
studio_subparser = studio_parser.add_subparsers(
dest="cmd",
help="Use `Datachain studio CMD --help` to display command-specific help.",
required=True,
)

studio_login_help = "Authenticate Datachain with Studio host"
studio_login_description = (
"By default, this command authenticates the Datachain with Studio\n"
"using default scopes and assigns a random name as the token name."
)
login_parser = studio_subparser.add_parser(
"login",
parents=[parent_parser],
description=studio_login_description,
help=studio_login_help,
)

login_parser.add_argument(
"-H",
"--hostname",
action="store",
default=None,
help="The hostname of the Studio instance to authenticate with.",
)
login_parser.add_argument(
"-s",
"--scopes",
action="store",
default=None,
help="The scopes for the authentication token. ",
)

login_parser.add_argument(
"-n",
"--name",
action="store",
default=None,
help="The name of the authentication token. It will be used to\n"
"identify token shown in Studio profile.",
)

login_parser.add_argument(
"--no-open",
action="store_true",
default=False,
help="Use authentication flow based on user code.\n"
"You will be presented with user code to enter in browser.\n"
"Datachain will also use this if it cannot launch browser on your behalf.",
)

studio_logout_help = "Logout user from Studio"
studio_logout_description = "This removes the studio token from your global config."

studio_subparser.add_parser(
"logout",
parents=[parent_parser],
description=studio_logout_description,
help=studio_logout_help,
)

studio_token_help = "View the token datachain uses to contact Studio" # noqa: S105 # nosec B105

studio_subparser.add_parser(
"token",
parents=[parent_parser],
description=studio_token_help,
help=studio_token_help,
)


def get_parser() -> ArgumentParser: # noqa: PLR0915
try:
__version__ = version("datachain")
Expand Down Expand Up @@ -225,6 +312,8 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
help="Use a different filename for the resulting .edatachain file",
)

add_studio_parser(subp, parent_parser)

parse_pull = subp.add_parser(
"pull",
parents=[parent_parser],
Expand Down Expand Up @@ -1000,6 +1089,8 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09
clear_cache(catalog)
elif args.command == "gc":
garbage_collect(catalog)
elif args.command == "studio":
process_studio_cli_args(args)
else:
print(f"invalid command: {args.command}", file=sys.stderr)
return 1
Expand Down
26 changes: 16 additions & 10 deletions src/datachain/config.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,49 @@
from collections.abc import Mapping
from contextlib import contextmanager
from enum import Enum
from typing import Optional, Union

from tomlkit import TOMLDocument, dump, load

from datachain.utils import DataChainDir, global_config_dir, system_config_dir


# Define an enum with value system, global and local
class ConfigLevel(Enum):
SYSTEM = "system"
GLOBAL = "global"
LOCAL = "local"


class Config:
SYSTEM_LEVELS = ("system", "global")
LOCAL_LEVELS = ("local",)
SYSTEM_LEVELS = (ConfigLevel.SYSTEM, ConfigLevel.GLOBAL)
LOCAL_LEVELS = (ConfigLevel.LOCAL,)

# In the order of precedence
LEVELS = SYSTEM_LEVELS + LOCAL_LEVELS

CONFIG = "config"

def __init__(
self,
level: Optional[str] = None,
level: Optional[ConfigLevel] = None,
):
self.level = level

self.init()

@classmethod
def get_dir(cls, level: Optional[str]) -> str:
if level == "system":
def get_dir(cls, level: Optional[ConfigLevel]) -> str:
if level == ConfigLevel.SYSTEM:
return system_config_dir()
if level == "global":
if level == ConfigLevel.GLOBAL:
return global_config_dir()

return DataChainDir.find().root
return str(DataChainDir.find().root)

def init(self):
d = DataChainDir(self.get_dir(self.level))
d.init()

def load_one(self, level: Optional[str] = None) -> TOMLDocument:
def load_one(self, level: Optional[ConfigLevel] = None) -> TOMLDocument:
config_path = DataChainDir(self.get_dir(level)).config

try:
Expand Down
97 changes: 97 additions & 0 deletions src/datachain/studio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
from typing import TYPE_CHECKING

from datachain.config import Config, ConfigLevel
from datachain.error import DataChainError
from datachain.utils import STUDIO_URL

if TYPE_CHECKING:
from argparse import Namespace

POST_LOGIN_MESSAGE = (
"Once you've logged in, return here "
"and you'll be ready to start using Datachain with Studio."
)


def process_studio_cli_args(args: "Namespace"):
if args.cmd == "login":
return login(args)
if args.cmd == "logout":
return logout()
if args.cmd == "token":
return token()
raise DataChainError(f"Unknown command '{args.cmd}'.")

Check warning on line 24 in src/datachain/studio.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/studio.py#L24

Added line #L24 was not covered by tests


def login(args: "Namespace"):
from dvc_studio_client.auth import StudioAuthError, get_access_token

config = Config().read().get("studio", {})
name = args.name
hostname = (
args.hostname
or os.environ.get("DVC_STUDIO_HOSTNAME")
or config.get("url")
or STUDIO_URL
)
scopes = args.scopes

if config.get("url", hostname) == hostname and "token" in config:
raise DataChainError(

Check warning on line 41 in src/datachain/studio.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/studio.py#L41

Added line #L41 was not covered by tests
"Token already exists. "
"To login with a different token, "
"logout using `datachain studio logout`."
)

open_browser = not args.no_open
try:
_, access_token = get_access_token(
token_name=name,
hostname=hostname,
scopes=scopes,
open_browser=open_browser,
client_name="Datachain",
post_login_message=POST_LOGIN_MESSAGE,
)
except StudioAuthError as exc:
raise DataChainError(f"Failed to authenticate with Studio: {exc}") from exc

config_path = save_config(hostname, access_token)
print(f"Authentication complete. Saved token to {config_path}.")
return 0


def logout():
with Config(ConfigLevel.GLOBAL).edit() as conf:
token = conf.get("studio", {}).get("token")
if not token:
raise DataChainError(
"Not logged in to Studio. Log in with 'datachain studio login'."
)

del conf["studio"]["token"]

print("Logged out from Studio. (you can log back in with 'datachain studio login')")


def token():
config = Config().read().get("studio", {})
token = config.get("token")
if not token:
raise DataChainError(
"Not logged in to Studio. Log in with 'datachain studio login'."
)

print(token)


def save_config(hostname, token):
config = Config(ConfigLevel.GLOBAL)
with config.edit() as conf:
studio_conf = conf.get("studio", {})
studio_conf["url"] = hostname
studio_conf["token"] = token
conf["studio"] = studio_conf

return config.config_file()
6 changes: 6 additions & 0 deletions src/datachain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,19 @@ def __init__(
if config is not None
else osp.join(self.root, self.CONFIG)
)
self.config = (
osp.abspath(config)
if config is not None
else osp.join(self.root, self.CONFIG)
)

def init(self):
os.makedirs(self.root, exist_ok=True)
os.makedirs(self.cache, exist_ok=True)
os.makedirs(self.tmp, exist_ok=True)
os.makedirs(osp.split(self.db)[0], exist_ok=True)
os.makedirs(osp.split(self.config)[0], exist_ok=True)
os.makedirs(osp.split(self.config)[0], exist_ok=True)

@classmethod
def default_root(cls) -> str:
Expand Down
20 changes: 19 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
from datachain.dataset import DatasetRecord
from datachain.lib.dc import DataChain
from datachain.query.session import Session
from datachain.utils import DataChainDir
from datachain.utils import (
ENV_DATACHAIN_GLOBAL_CONFIG_DIR,
ENV_DATACHAIN_SYSTEM_CONFIG_DIR,
DataChainDir,
)

from .utils import DEFAULT_TREE, instantiate_tree

Expand All @@ -39,6 +43,20 @@ def add_test_env():
os.environ["DATACHAIN_TEST"] = "true"


@pytest.fixture(autouse=True)
def global_config_dir(monkeypatch, tmp_path_factory):
global_dir = str(tmp_path_factory.mktemp("studio-login-global"))
monkeypatch.setenv(ENV_DATACHAIN_GLOBAL_CONFIG_DIR, global_dir)
yield global_dir


@pytest.fixture(autouse=True)
def system_config_dir(monkeypatch, tmp_path_factory):
system_dir = str(tmp_path_factory.mktemp("studio-login-system"))
monkeypatch.setenv(ENV_DATACHAIN_SYSTEM_CONFIG_DIR, system_dir)
yield system_dir


@pytest.fixture(autouse=True)
def virtual_memory(mocker):
class VirtualMemory(NamedTuple):
Expand Down
Loading