diff --git a/google/cloud/storage/client.py b/google/cloud/storage/client.py index 2f56d8719..74b6061c0 100644 --- a/google/cloud/storage/client.py +++ b/google/cloud/storage/client.py @@ -297,6 +297,18 @@ def universe_domain(self): def api_endpoint(self): return self._connection.API_BASE_URL + def update_user_agent(self, user_agent): + """Update the user-agent string for this client. + + :type user_agent: str + :param user_agent: The string to add to the user-agent. + """ + existing_user_agent = self._connection._client_info.user_agent + if existing_user_agent is None: + self._connection.user_agent = user_agent + else: + self._connection.user_agent = f"{user_agent} {existing_user_agent}" + @property def _connection(self): """Get connection or batch on the client. diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 99de31961..d5723740e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1528,6 +1528,34 @@ def test_create_bucket_w_conflict_w_user_project(self): _target_object=mock.ANY, ) + def test_update_user_agent_when_default_clientinfo_provided(self): + from google.cloud._http import ClientInfo + + client_info = ClientInfo() + + client = self._make_one(project=None, client_info=client_info) + self.assertGreater(len(client._connection.user_agent), 0) + + client.update_user_agent("my-test-agent/1.0") + self.assertIn("my-test-agent/1.0", client._connection.user_agent) + + def test_update_user_agent_when_none_clientinfo_provided(self): + client = self._make_one(project=None) + client.update_user_agent("my-test-agent/1.0") + + self.assertIn("my-test-agent/1.0", client._connection.user_agent) + + def test_update_user_agent_with_existing_user_agent(self): + from google.cloud._http import ClientInfo + + client_info = ClientInfo(user_agent="existing-agent/2.0") + client = self._make_one(project=None, client_info=client_info) + client.update_user_agent("my-test-agent/1.0") + + self.assertIn( + "my-test-agent/1.0 existing-agent/2.0", client._connection.user_agent + ) + @mock.patch("warnings.warn") def test_create_bucket_w_requester_pays_deprecated(self, mock_warn): from google.cloud.storage.bucket import Bucket