diff --git a/dagshub/auth/tokens.py b/dagshub/auth/tokens.py index 63b6cecd..4c17999d 100644 --- a/dagshub/auth/tokens.py +++ b/dagshub/auth/tokens.py @@ -18,7 +18,7 @@ DagshubAuthenticator, ) from dagshub.common import config -from dagshub.common.helpers import http_request +from dagshub.common.helpers import http_request, log_message from dagshub.common.util import multi_urljoin logger = logging.getLogger(__name__) @@ -43,6 +43,7 @@ def __init__(self, cache_location: str = None, **kwargs): self._known_good_tokens: Dict[str, Set[DagshubTokenABC]] = {} self.__token_access_lock = threading.RLock() + self._accessing_as_was_printed = False @property def _token_cache(self): @@ -126,6 +127,11 @@ def get_token_object(self, host: str = None, fail_if_no_token: bool = False, **k host = host or config.host if host == config.host and config.token is not None: + user = TokenStorage.get_username_of_token(config.token, host) + if user is not None: + self._print_accessing_as(user) + else: + raise RuntimeError("Provided DagsHub token is not valid") return EnvVarDagshubToken(config.token, host) with self._token_access_lock: @@ -135,6 +141,7 @@ def get_token_object(self, host: str = None, fail_if_no_token: bool = False, **k self._known_good_tokens[host] = set() good_token_set = self._known_good_tokens[host] good_token = None + good_user = None token_queue = list(sorted(tokens, key=lambda t: t.priority)) for token in token_queue: @@ -146,9 +153,11 @@ def get_token_object(self, host: str = None, fail_if_no_token: bool = False, **k good_token = token break # Check token validity - elif self.is_valid_token(token, host): + user = TokenStorage.get_username_of_token(token, host) + if user is not None: good_token = token good_token_set.add(token) + good_user = user # Remove invalid token from the cache else: self.invalidate_token(token, host) @@ -173,6 +182,7 @@ def get_token_object(self, host: str = None, fail_if_no_token: bool = False, **k self._token_cache[host] = tokens self._store_cache_file() + self._print_accessing_as(good_user) return good_token def get_token(self, host: str = None, fail_if_no_token: bool = False, **kwargs) -> str: @@ -202,9 +212,9 @@ def _is_expired(token: Dict[str, str]) -> bool: return is_expired @staticmethod - def is_valid_token(token: Union[str, Auth, DagshubTokenABC], host: str) -> bool: + def get_username_of_token(token: Union[str, Auth, DagshubTokenABC], host: str) -> Optional[Dict]: """ - Check for token validity + Check for token validity and return the dictionary with the info of the user of the token Args: token: token to check validity @@ -219,13 +229,24 @@ def is_valid_token(token: Union[str, Auth, DagshubTokenABC], host: str) -> bool: resp = http_request("GET", check_url, auth=auth) try: - # 500's might be ok since they're server errors, so check only for 400's - assert not (400 <= resp.status_code <= 499) - if resp.status_code == 200: - assert "login" in resp.json() - return True + assert resp.status_code == 200 + user = resp.json() + assert "login" in user + assert "username" in user + return user except AssertionError: - return False + return None + + @staticmethod + def is_valid_token(token: Union[str, Auth, DagshubTokenABC], host: str) -> bool: + """ + Check for token validity + + Args: + token: token to check validity + host: which host to connect against + """ + return TokenStorage.get_username_of_token(token, host) is not None def _load_cache_file(self) -> Dict[str, List[DagshubTokenABC]]: logger.debug(f"Loading token cache from {self.cache_location}") @@ -281,6 +302,21 @@ def _store_cache_file(self): logger.error(f"Error while storing DagsHub token cache: {traceback.format_exc()}") raise + def _print_accessing_as(self, user: Dict): + """ + This function prints a message to the log that we are accessing as a certain user. + It does this only once per command, to avoid spamming the logs. + It called after successful token validation. + + """ + if self._accessing_as_was_printed: + return + + username = user["username"] + + log_message(f"Accessing as {username}") + self._accessing_as_was_printed = True + def __getstate__(self): d = self.__dict__ # Don't pickle the lock. This will make it so multiple authenticators might request for tokens at the same time diff --git a/tests/common/test_determine_repo.py b/tests/common/test_determine_repo.py index 2a68ca98..2b0a9917 100644 --- a/tests/common/test_determine_repo.py +++ b/tests/common/test_determine_repo.py @@ -16,6 +16,14 @@ YieldFixture = Generator[T, None, None] +@pytest.fixture +def mock_get_username_of_token(mocker): + return mocker.patch( + "dagshub.auth.tokens.TokenStorage.get_username_of_token", + return_value={"username": "testuser", "login": "testlogin"}, + ) + + @pytest.fixture def repo_name() -> str: return f"user/repo-{uuid.uuid4()}" @@ -29,7 +37,7 @@ def repo_name() -> str: "https://somewhere.else:8080/prefix", ] ) -def dagshub_host(request) -> YieldFixture[str]: +def dagshub_host(request, mock_get_username_of_token) -> YieldFixture[str]: host = request.param old_value = dagshub.common.config.host dagshub.common.config.host = host diff --git a/tests/data_engine/conftest.py b/tests/data_engine/conftest.py index 3fe1536f..8b7619e3 100644 --- a/tests/data_engine/conftest.py +++ b/tests/data_engine/conftest.py @@ -13,7 +13,7 @@ @pytest.fixture -def ds(mocker) -> Datasource: +def ds(mocker, mock_get_username_of_token) -> Datasource: ds_state = datasources.DatasourceState(id=1, name="test-dataset", repo="kirill/repo") ds_state.path = "repo://kirill/repo/data/" mocker.patch.object(ds_state, "client") @@ -77,3 +77,11 @@ def query_result(ds, some_datapoints): fields.append(f) qr = QueryResult(datasource=ds, _entries=some_datapoints, fields=fields) return qr + + +@pytest.fixture +def mock_get_username_of_token(mocker): + mocker.patch( + "dagshub.auth.tokens.TokenStorage.get_username_of_token", + return_value={"username": "testuser", "login": "testlogin"}, + ) diff --git a/tests/data_engine/test_datasource_state.py b/tests/data_engine/test_datasource_state.py index 581c736e..47c7f26d 100644 --- a/tests/data_engine/test_datasource_state.py +++ b/tests/data_engine/test_datasource_state.py @@ -22,7 +22,7 @@ ("repo://user/repo/branch/with/slashes:/", "user", "repo", "branch/with/slashes", "/"), ], ) -def test_repo_regex(in_str, user, repo, revision, prefix): +def test_repo_regex(in_str, user, repo, revision, prefix, mock_get_username_of_token): ds = DatasourceState(repo="user/repo") ds.path = in_str ds.source_type = DatasourceType.REPOSITORY @@ -40,7 +40,7 @@ def test_repo_regex(in_str, user, repo, revision, prefix): @pytest.mark.parametrize( "in_str", ["s3://user/repo/prefix", "user/repo/", "repo://user/", "repo://" "repo://user/repo/wrong\\branch:"] ) -def test_repo_regex_incorrect(in_str): +def test_repo_regex_incorrect(in_str, mock_get_username_of_token): ds = DatasourceState(repo="user/repo") ds.path = in_str ds.source_type = DatasourceType.REPOSITORY @@ -60,7 +60,7 @@ def test_repo_regex_incorrect(in_str): ("s3://bucket_with.weird-chars/longer/prefix", "s3", "bucket_with.weird-chars", "/longer/prefix"), ], ) -def test_bucket_regex(in_str, schema, bucket, prefix): +def test_bucket_regex(in_str, schema, bucket, prefix, mock_get_username_of_token): ds = DatasourceState(repo="user/repo") ds.path = in_str ds.source_type = DatasourceType.BUCKET @@ -78,7 +78,7 @@ def test_bucket_regex(in_str, schema, bucket, prefix): "s3://bucket.www.com/prefix", ], ) -def test_bucket_regex_incorrect(in_str): +def test_bucket_regex_incorrect(in_str, mock_get_username_of_token): ds = DatasourceState(repo="user/repo") ds.path = in_str ds.source_type = DatasourceType.REPOSITORY diff --git a/tests/dda/mock_api.py b/tests/dda/mock_api.py index 336f6138..2b5dafa7 100644 --- a/tests/dda/mock_api.py +++ b/tests/dda/mock_api.py @@ -2,6 +2,7 @@ from httpx import Response from respx import MockRouter, Route +from tests.util import valid_token_side_effect class MockApi(MockRouter): @@ -17,15 +18,24 @@ def __init__(self, git_repo, user="user", reponame="repo", *args, **kwargs): route_dict = {k: (self._endpoints[k], self._responses[k]) for k in self._endpoints} for route_name in route_dict: endpoint_regex, return_value = route_dict[route_name] - self.route(name=route_name, url__regex=endpoint_regex).mock(return_value) + + # if return value is a response, mock it directly; otherwise, mock the side effect + if isinstance(return_value, Response): + self.route(name=route_name, url__regex=endpoint_regex).mock(return_value) + else: + self.route(name=route_name, url__regex=endpoint_regex).mock(side_effect=return_value) @property def repourlpath(self): return f"{self.user}/{self.reponame}" + @property + def api_prefix(self): + return "/api/v1" + @property def repoapipath(self): - return f"/api/v1/repos/{self.repourlpath}" + return f"{self.api_prefix}/repos/{self.repourlpath}" @property def repophysicalpath(self): @@ -64,6 +74,7 @@ def _default_endpoints_and_responses(self): "branches": rf"{self.repoapipath}/branches/?$", "list_root": rf"{self.repoapipath}/content/{self.current_revision}/$", "storages": rf"{self.repoapipath}/storage/?$", + "user": rf"{self.api_prefix}/user", } responses = { @@ -213,6 +224,7 @@ def _default_endpoints_and_responses(self): } ], ), + "user": valid_token_side_effect, } return endpoints, responses diff --git a/tests/dda/test_tokens.py b/tests/dda/test_tokens.py index 6c91553d..6f772175 100644 --- a/tests/dda/test_tokens.py +++ b/tests/dda/test_tokens.py @@ -12,30 +12,12 @@ ) -def valid_token_side_effect(request: httpx.Request) -> httpx.Response: - if request.headers["Authorization"] == "Bearer good-token": - return httpx.Response( - 200, - json={ - "id": 1, - "login": "user", - "full_name": "user", - "avatar_url": "random_url", - "username": "user", - }, - ) - else: - return httpx.Response(401) - - @pytest.fixture def token_api(mock_api): # Disable the env var token for these tests explicitly old_token_val = dagshub.common.config.token dagshub.common.config.token = None - mock_api.get("https://dagshub.com/api/v1/user").mock(side_effect=valid_token_side_effect) - mock_api.post("https://dagshub.com/api/v1/middleman").mock(httpx.Response(200, json="code")) a_day_away = datetime.datetime.utcnow() + datetime.timedelta(days=1) diff --git a/tests/util.py b/tests/util.py index 5516e2c8..c6e98cfb 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,5 +1,6 @@ import contextlib import os +import httpx @contextlib.contextmanager @@ -9,3 +10,22 @@ def remember_cwd(): yield finally: os.chdir(curdir) + + +def valid_token_side_effect(request: httpx.Request) -> httpx.Response: + valid_auth_tokens = ["good-token", "token-set-in-env-var", "token"] + + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.split(" ")[1] in valid_auth_tokens: + return httpx.Response( + 200, + json={ + "id": 1, + "login": "user", + "full_name": "user", + "avatar_url": "random_url", + "username": "user", + }, + ) + else: + return httpx.Response(401)