Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement / Print “accessing as <user>” in client #493

Merged
merged 11 commits into from
Jul 2, 2024
56 changes: 46 additions & 10 deletions dagshub/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
DagshubAuthenticator,
)
from dagshub.common import config
from dagshub.common.helpers import http_request
from dagshub.common.helpers import http_request, log_message
from dagshub.common.util import multi_urljoin

logger = logging.getLogger(__name__)
Expand All @@ -43,6 +43,7 @@ def __init__(self, cache_location: str = None, **kwargs):
self._known_good_tokens: Dict[str, Set[DagshubTokenABC]] = {}

self.__token_access_lock = threading.RLock()
self._accessing_as_was_printed = False

@property
def _token_cache(self):
Expand Down Expand Up @@ -126,6 +127,11 @@ def get_token_object(self, host: str = None, fail_if_no_token: bool = False, **k

host = host or config.host
if host == config.host and config.token is not None:
user = TokenStorage.get_username_of_token(config.token, host)
if user is not None:
self._print_accessing_as(user)
else:
raise RuntimeError("Provided DagsHub token is not valid")
return EnvVarDagshubToken(config.token, host)

with self._token_access_lock:
Expand All @@ -135,6 +141,7 @@ def get_token_object(self, host: str = None, fail_if_no_token: bool = False, **k
self._known_good_tokens[host] = set()
good_token_set = self._known_good_tokens[host]
good_token = None
good_user = None
token_queue = list(sorted(tokens, key=lambda t: t.priority))

for token in token_queue:
Expand All @@ -146,9 +153,11 @@ def get_token_object(self, host: str = None, fail_if_no_token: bool = False, **k
good_token = token
break
# Check token validity
elif self.is_valid_token(token, host):
user = TokenStorage.get_username_of_token(token, host)
if user is not None:
good_token = token
good_token_set.add(token)
good_user = user
# Remove invalid token from the cache
else:
self.invalidate_token(token, host)
Expand All @@ -173,6 +182,7 @@ def get_token_object(self, host: str = None, fail_if_no_token: bool = False, **k
self._token_cache[host] = tokens
self._store_cache_file()

self._print_accessing_as(good_user)
return good_token

def get_token(self, host: str = None, fail_if_no_token: bool = False, **kwargs) -> str:
Expand Down Expand Up @@ -202,9 +212,9 @@ def _is_expired(token: Dict[str, str]) -> bool:
return is_expired

@staticmethod
def is_valid_token(token: Union[str, Auth, DagshubTokenABC], host: str) -> bool:
def get_username_of_token(token: Union[str, Auth, DagshubTokenABC], host: str) -> Optional[Dict]:
"""
Check for token validity
Check for token validity and return the dictionary with the info of the user of the token

Args:
token: token to check validity
Expand All @@ -219,13 +229,24 @@ def is_valid_token(token: Union[str, Auth, DagshubTokenABC], host: str) -> bool:
resp = http_request("GET", check_url, auth=auth)

try:
# 500's might be ok since they're server errors, so check only for 400's
assert not (400 <= resp.status_code <= 499)
if resp.status_code == 200:
assert "login" in resp.json()
return True
assert resp.status_code == 200
user = resp.json()
assert "login" in user
assert "username" in user
return user
except AssertionError:
return False
return None

@staticmethod
def is_valid_token(token: Union[str, Auth, DagshubTokenABC], host: str) -> bool:
    """
    Check for token validity

    Args:
        token: token to check validity
        host: which host to connect against

    Returns:
        True when ``TokenStorage.get_username_of_token`` returns user info for
        the token, False when it returns None.
    """
    user_info = TokenStorage.get_username_of_token(token, host)
    return user_info is not None

def _load_cache_file(self) -> Dict[str, List[DagshubTokenABC]]:
logger.debug(f"Loading token cache from {self.cache_location}")
Expand Down Expand Up @@ -281,6 +302,21 @@ def _store_cache_file(self):
logger.error(f"Error while storing DagsHub token cache: {traceback.format_exc()}")
raise

def _print_accessing_as(self, user: Optional[Dict]):
    """
    Log a message saying which user we are accessing as.

    The message is emitted at most once per instance (tracked through
    ``self._accessing_as_was_printed``) to avoid spamming the logs.
    It is meant to be called after a successful token validation.

    Args:
        user: User info dictionary; its "username" entry is what gets logged.
            May be None when the caller has no user info at hand, in which
            case nothing is logged and the "already printed" flag is left
            untouched.
    """
    # get_token_object() initialises its user info to None and only fills it in
    # when a token is freshly validated, so None can reach this method.
    # Bail out instead of failing on the subscript below.
    if self._accessing_as_was_printed or user is None:
        return

    username = user["username"]

    log_message(f"Accessing as {username}")
    self._accessing_as_was_printed = True

def __getstate__(self):
d = self.__dict__
# Don't pickle the lock. This will make it so multiple authenticators might request for tokens at the same time
Expand Down
10 changes: 9 additions & 1 deletion tests/common/test_determine_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
YieldFixture = Generator[T, None, None]


@pytest.fixture
def mock_get_username_of_token(mocker):
    """Patch ``TokenStorage.get_username_of_token`` to return a fixed user dict and return the mock."""
    fake_user = {"username": "testuser", "login": "testlogin"}
    patched = mocker.patch(
        "dagshub.auth.tokens.TokenStorage.get_username_of_token",
        return_value=fake_user,
    )
    return patched


@pytest.fixture
def repo_name() -> str:
return f"user/repo-{uuid.uuid4()}"
Expand All @@ -29,7 +37,7 @@ def repo_name() -> str:
"https://somewhere.else:8080/prefix",
]
)
def dagshub_host(request) -> YieldFixture[str]:
def dagshub_host(request, mock_get_username_of_token) -> YieldFixture[str]:
host = request.param
old_value = dagshub.common.config.host
dagshub.common.config.host = host
Expand Down
10 changes: 9 additions & 1 deletion tests/data_engine/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@pytest.fixture
def ds(mocker) -> Datasource:
def ds(mocker, mock_get_username_of_token) -> Datasource:
ds_state = datasources.DatasourceState(id=1, name="test-dataset", repo="kirill/repo")
ds_state.path = "repo://kirill/repo/data/"
mocker.patch.object(ds_state, "client")
Expand Down Expand Up @@ -77,3 +77,11 @@ def query_result(ds, some_datapoints):
fields.append(f)
qr = QueryResult(datasource=ds, _entries=some_datapoints, fields=fields)
return qr


@pytest.fixture
def mock_get_username_of_token(mocker):
    """
    Patch ``TokenStorage.get_username_of_token`` to return a fixed user dict.

    Returns:
        The mock object created by ``mocker.patch``, so that tests requesting
        this fixture can inspect or reconfigure it (previously the fixture
        value was None because the mock was discarded).
    """
    return mocker.patch(
        "dagshub.auth.tokens.TokenStorage.get_username_of_token",
        return_value={"username": "testuser", "login": "testlogin"},
    )
8 changes: 4 additions & 4 deletions tests/data_engine/test_datasource_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
("repo://user/repo/branch/with/slashes:/", "user", "repo", "branch/with/slashes", "/"),
],
)
def test_repo_regex(in_str, user, repo, revision, prefix):
def test_repo_regex(in_str, user, repo, revision, prefix, mock_get_username_of_token):
ds = DatasourceState(repo="user/repo")
ds.path = in_str
ds.source_type = DatasourceType.REPOSITORY
Expand All @@ -40,7 +40,7 @@ def test_repo_regex(in_str, user, repo, revision, prefix):
@pytest.mark.parametrize(
"in_str", ["s3://user/repo/prefix", "user/repo/", "repo://user/", "repo://" "repo://user/repo/wrong\\branch:"]
)
def test_repo_regex_incorrect(in_str):
def test_repo_regex_incorrect(in_str, mock_get_username_of_token):
ds = DatasourceState(repo="user/repo")
ds.path = in_str
ds.source_type = DatasourceType.REPOSITORY
Expand All @@ -60,7 +60,7 @@ def test_repo_regex_incorrect(in_str):
("s3://bucket_with.weird-chars/longer/prefix", "s3", "bucket_with.weird-chars", "/longer/prefix"),
],
)
def test_bucket_regex(in_str, schema, bucket, prefix):
def test_bucket_regex(in_str, schema, bucket, prefix, mock_get_username_of_token):
ds = DatasourceState(repo="user/repo")
ds.path = in_str
ds.source_type = DatasourceType.BUCKET
Expand All @@ -78,7 +78,7 @@ def test_bucket_regex(in_str, schema, bucket, prefix):
"s3://bucket.www.com/prefix",
],
)
def test_bucket_regex_incorrect(in_str):
def test_bucket_regex_incorrect(in_str, mock_get_username_of_token):
ds = DatasourceState(repo="user/repo")
ds.path = in_str
ds.source_type = DatasourceType.REPOSITORY
Expand Down
16 changes: 14 additions & 2 deletions tests/dda/mock_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from httpx import Response
from respx import MockRouter, Route
from tests.util import valid_token_side_effect


class MockApi(MockRouter):
Expand All @@ -17,15 +18,24 @@ def __init__(self, git_repo, user="user", reponame="repo", *args, **kwargs):
route_dict = {k: (self._endpoints[k], self._responses[k]) for k in self._endpoints}
for route_name in route_dict:
endpoint_regex, return_value = route_dict[route_name]
self.route(name=route_name, url__regex=endpoint_regex).mock(return_value)

# if return value is a response, mock it directly; otherwise, mock the side effect
if isinstance(return_value, Response):
self.route(name=route_name, url__regex=endpoint_regex).mock(return_value)
else:
self.route(name=route_name, url__regex=endpoint_regex).mock(side_effect=return_value)

@property
def repourlpath(self):
    """Return the ``<user>/<reponame>`` path built from this instance's ``user`` and ``reponame``."""
    owner, name = self.user, self.reponame
    return f"{owner}/{name}"

@property
def api_prefix(self):
    """Return the root path prefix from which the API endpoint paths are built."""
    prefix = "/api/v1"
    return prefix

@property
def repoapipath(self):
return f"/api/v1/repos/{self.repourlpath}"
return f"{self.api_prefix}/repos/{self.repourlpath}"

@property
def repophysicalpath(self):
Expand Down Expand Up @@ -64,6 +74,7 @@ def _default_endpoints_and_responses(self):
"branches": rf"{self.repoapipath}/branches/?$",
"list_root": rf"{self.repoapipath}/content/{self.current_revision}/$",
"storages": rf"{self.repoapipath}/storage/?$",
"user": rf"{self.api_prefix}/user",
}

responses = {
Expand Down Expand Up @@ -213,6 +224,7 @@ def _default_endpoints_and_responses(self):
}
],
),
"user": valid_token_side_effect,
}

return endpoints, responses
Expand Down
18 changes: 0 additions & 18 deletions tests/dda/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,12 @@
)


def valid_token_side_effect(request: httpx.Request) -> httpx.Response:
if request.headers["Authorization"] == "Bearer good-token":
return httpx.Response(
200,
json={
"id": 1,
"login": "user",
"full_name": "user",
"avatar_url": "random_url",
"username": "user",
},
)
else:
return httpx.Response(401)


@pytest.fixture
def token_api(mock_api):
# Disable the env var token for these tests explicitly
old_token_val = dagshub.common.config.token
dagshub.common.config.token = None

mock_api.get("https://dagshub.com/api/v1/user").mock(side_effect=valid_token_side_effect)

mock_api.post("https://dagshub.com/api/v1/middleman").mock(httpx.Response(200, json="code"))

a_day_away = datetime.datetime.utcnow() + datetime.timedelta(days=1)
Expand Down
20 changes: 20 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import os
import httpx


@contextlib.contextmanager
Expand All @@ -9,3 +10,22 @@ def remember_cwd():
yield
finally:
os.chdir(curdir)


def valid_token_side_effect(request: httpx.Request) -> httpx.Response:
    """
    Mock side effect for the "current user" endpoint.

    Looks at the ``Authorization`` header of the request, expected in the
    ``"<scheme> <token>"`` form, and compares the token part against a fixed
    set of tokens that the tests treat as valid.

    Args:
        request: The intercepted request.

    Returns:
        A 200 response with a fixed user JSON payload when the token is one of
        the known-valid ones; a 401 response otherwise (including when the
        header is missing or has no token part).
    """
    valid_auth_tokens = {"good-token", "token-set-in-env-var", "token"}

    auth_header = request.headers.get("Authorization") or ""
    # partition() never raises: a header without a space gives an empty remainder,
    # which falls through to the 401 branch instead of an IndexError.
    _scheme, _sep, remainder = auth_header.partition(" ")
    # Same element the original picked with split(" ")[1]: the text between the
    # first and the second space.
    token = remainder.split(" ", 1)[0]

    if token in valid_auth_tokens:
        return httpx.Response(
            200,
            json={
                "id": 1,
                "login": "user",
                "full_name": "user",
                "avatar_url": "random_url",
                "username": "user",
            },
        )
    return httpx.Response(401)
Loading