diff --git a/general-superstaq/general_superstaq/service.py b/general-superstaq/general_superstaq/service.py index 3e60b8c35..55f100489 100644 --- a/general-superstaq/general_superstaq/service.py +++ b/general-superstaq/general_superstaq/service.py @@ -171,11 +171,29 @@ def get_user_info(self, *, name: str) -> list[dict[str, str | float]]: ... @overload def get_user_info(self, *, email: str) -> list[dict[str, str | float]]: ... + @overload + def get_user_info(self, *, user_id: int) -> list[dict[str, str | float]]: ... + + @overload + def get_user_info(self, *, name: str, user_id: int) -> list[dict[str, str | float]]: ... + + @overload + def get_user_info(self, *, email: str, user_id: int) -> list[dict[str, str | float]]: ... + @overload def get_user_info(self, *, name: str, email: str) -> list[dict[str, str | float]]: ... + @overload def get_user_info( - self, *, name: str | None = None, email: str | None = None + self, *, name: str, email: str, user_id: int + ) -> list[dict[str, str | float]]: ... + + def get_user_info( + self, + *, + name: str | None = None, + email: str | None = None, + user_id: int | None = None, ) -> dict[str, str | float] | list[dict[str, str | float]]: """Gets a dictionary of the user's info. @@ -186,16 +204,17 @@ def get_user_info( Args: name: A name to search by. Defaults to None. - email: An email address to search by. Defaults to None + email: An email address to search by. Defaults to None. + user_id: A user ID to search by. Defaults to None. Returns: A dictionary of the user information. In the case that either the name or email query kwarg is used, a list of dictionaries is returned, corresponding to the user information for each user that matches the query. """ - user_info = self._client.get_user_info(name=name, email=email) + user_info = self._client.get_user_info(name=name, email=email, user_id=user_id) - if name is None and email is None: + if name is None and email is None and user_id is None: # If no query then return the only element in the list. return user_info[0] diff --git a/general-superstaq/general_superstaq/service_test.py b/general-superstaq/general_superstaq/service_test.py index be0ca35d4..0f0ac05e0 100644 --- a/general-superstaq/general_superstaq/service_test.py +++ b/general-superstaq/general_superstaq/service_test.py @@ -95,7 +95,7 @@ def test_get_user_info(mock_get_request: mock.MagicMock) -> None: "role": "free_trial", "balance": 30.0, } - mock_get_request.assert_called_once_with("/get_user_info", query={}) + mock_get_request.assert_called_once_with("/user_info", query={}) @mock.patch( @@ -120,7 +120,7 @@ def test_get_user_info_name_query(mock_get_request: mock.MagicMock) -> None: "balance": 30.0, } ] - mock_get_request.assert_called_once_with("/get_user_info", query={"name": "Alice"}) + mock_get_request.assert_called_once_with("/user_info", query={"name": "Alice"}) @mock.patch( @@ -145,7 +145,7 @@ def test_get_user_info_email_query(mock_get_request: mock.MagicMock) -> None: "balance": 30.0, } ] - mock_get_request.assert_called_once_with("/get_user_info", query={"email": "example@email.com"}) + mock_get_request.assert_called_once_with("/user_info", query={"email": "example@email.com"}) @mock.patch( diff --git a/general-superstaq/general_superstaq/superstaq_client.py b/general-superstaq/general_superstaq/superstaq_client.py index 2ab045bee..936337210 100644 --- a/general-superstaq/general_superstaq/superstaq_client.py +++ b/general-superstaq/general_superstaq/superstaq_client.py @@ -238,7 +238,7 @@ def get_balance(self) -> dict[str, float]: return self.get_request("/balance") def get_user_info( - self, name: str | None = None, email: str | None = None + self, name: str | None = None, email: str | None = None, user_id: int | None = None ) -> list[dict[str, str | float]]: """Gets a dictionary of the user's info. @@ -249,7 +249,8 @@ def get_user_info( Args: name: A name to search by. Defaults to None. - email: An email address to search by. Defaults to None + email: An email address to search by. Defaults to None. + user_id: A user ID to search by. Defaults to None. Returns: A list of dictionaries corresponding to the user @@ -264,7 +265,9 @@ def get_user_info( query["name"] = name if email is not None: query["email"] = email - user_info = self.get_request("/get_user_info", query=query) + if user_id is not None: + query["id"] = str(user_id) + user_info = self.get_request("/user_info", query=query) if not user_info: # Catch empty server response. This shouldn't happen as the server should return # an error code if something is wrong with the request. @@ -704,7 +707,8 @@ def get_request(self, endpoint: str, query: Mapping[str, object] | None = None) Args: endpoint: The endpoint to perform the GET request on. - query: An optional query json to include in the get request. + query: An optional query dictionary to include in the get request. + This query will be appended to the url. Returns: The response of the GET request. @@ -716,11 +720,14 @@ def request() -> requests.Response: Returns: The Flask GET request object. """ + if not query: + q_string = "" + else: + q_string = "?" + urllib.parse.urlencode(query) return self.session.get( - f"{self.url}{endpoint}", + f"{self.url}{endpoint}{q_string}", headers=self.headers, verify=self.verify_https, - json=query, ) response = self._make_request(request) diff --git a/general-superstaq/general_superstaq/superstaq_client_test.py b/general-superstaq/general_superstaq/superstaq_client_test.py index 3dbb9a794..403fec073 100644 --- a/general-superstaq/general_superstaq/superstaq_client_test.py +++ b/general-superstaq/general_superstaq/superstaq_client_test.py @@ -432,7 +432,6 @@ def test_superstaq_client_get_balance(mock_get: mock.MagicMock) -> None: f"http://example.com/{API_VERSION}/balance", headers=EXPECTED_HEADERS, verify=False, - json=None, ) @@ -990,9 +989,8 @@ def test_get_user_info(mock_get: mock.MagicMock) -> None: user_info = client.get_user_info() mock_get.assert_called_once_with( - f"http://example.com/{API_VERSION}/get_user_info", + f"http://example.com/{API_VERSION}/user_info", headers=EXPECTED_HEADERS, - json={}, verify=False, ) assert user_info == [{"Some": "Data"}] @@ -1010,9 +1008,26 @@ def test_get_user_info_query(mock_get: mock.MagicMock) -> None: user_info = client.get_user_info(name="Alice") mock_get.assert_called_once_with( - f"http://example.com/{API_VERSION}/get_user_info", + f"http://example.com/{API_VERSION}/user_info?name=Alice", + headers=EXPECTED_HEADERS, + verify=False, + ) + assert user_info == [{"Some": "Data"}] + + +@mock.patch("requests.Session.get") +def test_get_user_info_query_composite(mock_get: mock.MagicMock) -> None: + client = gss.superstaq_client._SuperstaqClient( + client_name="general-superstaq", + remote_host="http://example.com", + api_key="to_my_heart", + cq_token="cq-token", + ) + mock_get.return_value.json.return_value = {"example@email.com": {"Some": "Data"}} + user_info = client.get_user_info(user_id=42, name="Alice") + mock_get.assert_called_once_with( + f"http://example.com/{API_VERSION}/user_info?name=Alice&id=42", headers=EXPECTED_HEADERS, - json={"name": "Alice"}, verify=False, ) assert user_info == [{"Some": "Data"}] @@ -1035,8 +1050,7 @@ def test_get_user_info_empty_response(mock_get: mock.MagicMock) -> None: client.get_user_info() mock_get.assert_called_once_with( - f"http://example.com/{API_VERSION}/get_user_info", + f"http://example.com/{API_VERSION}/user_info", headers=EXPECTED_HEADERS, - json={}, verify=False, )