|
28 | 28 | from typing import Union |
29 | 29 |
|
30 | 30 | from google.genai import types |
| 31 | +from google.genai.errors import ClientError |
31 | 32 | from typing_extensions import override |
32 | 33 |
|
33 | 34 | from .. import version |
|
51 | 52 | _AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine' |
52 | 53 | _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID' |
53 | 54 |
|
| 55 | +_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """ |
| 56 | +For guidance on how to mitigate this issue, please refer to: |
| 57 | +
|
| 58 | +https://google.github.io/adk-docs/agents/models/#error-code-429-resource_exhausted |
| 59 | +""" |
| 60 | + |
| 61 | + |
| 62 | +class _ResourceExhaustedError(ClientError): |
| 63 | + """Represents an resources exhausted error received from the Model.""" |
| 64 | + |
| 65 | + def __init__( |
| 66 | + self, |
| 67 | + client_error: ClientError, |
| 68 | + ): |
| 69 | + super().__init__( |
| 70 | + code=client_error.code, |
| 71 | + response_json=client_error.details, |
| 72 | + response=client_error.response, |
| 73 | + ) |
| 74 | + |
| 75 | + def __str__(self): |
| 76 | +    # We can't override the actual message on ClientError, so we override |
| 77 | +    # this method instead. This ensures that when the exception is |
| 78 | +    # stringified (whether printed to the console or written to logs), the |
| 79 | +    # details the developer needs are included. |
| 80 | + base_message = super().__str__() |
| 81 | + return f'{_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE}\n\n{base_message}' |
| 82 | + |
54 | 83 |
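To make the effect of the wrapper concrete, the following is a minimal sketch (not part of this diff) that hand-builds a 429 `ClientError` and wraps it. The `response_json` payload is fabricated, and the sketch assumes `ClientError` simply stores the `code`, `response_json`, and `response` arguments it is given, as the constructor call above suggests.

```python
# Hypothetical usage sketch: the payload below is made up for illustration.
from google.genai.errors import ClientError

original = ClientError(
    code=429,
    response_json={
        'error': {
            'status': 'RESOURCE_EXHAUSTED',
            'message': 'Quota exceeded for requests per minute.',
        }
    },
    response=None,  # assuming a live HTTP response object is not required here
)

wrapped = _ResourceExhaustedError(original)

# str(wrapped) begins with _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE (the docs
# link) followed by the original ClientError text, so console output and logs
# both carry the mitigation hint.
print(wrapped)
```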
|
55 | 84 | class Gemini(BaseLlm): |
56 | 85 | """Integration for Gemini models. |
@@ -149,50 +178,61 @@ async def generate_content_async( |
149 | 178 | llm_request.config.http_options.headers |
150 | 179 | ) |
151 | 180 |
|
152 | | - if stream: |
153 | | - responses = await self.api_client.aio.models.generate_content_stream( |
154 | | - model=llm_request.model, |
155 | | - contents=llm_request.contents, |
156 | | - config=llm_request.config, |
157 | | - ) |
| 181 | + try: |
| 182 | + if stream: |
| 183 | + responses = await self.api_client.aio.models.generate_content_stream( |
| 184 | + model=llm_request.model, |
| 185 | + contents=llm_request.contents, |
| 186 | + config=llm_request.config, |
| 187 | + ) |
| 188 | + |
| 189 | +      # For SSE, similar to bidi (see the receive method in |
| 190 | +      # gemini_llm_connection.py), we need to mark each text content as |
| 191 | +      # partial; after all partial contents are sent, we send an accumulated |
| 192 | +      # event that contains all the previous partial content. The only |
| 193 | +      # difference is that bidi relies on the complete_turn flag to detect |
| 194 | +      # the end, while SSE depends on finish_reason. |
| 195 | + aggregator = StreamingResponseAggregator() |
| 196 | + async with Aclosing(responses) as agen: |
| 197 | + async for response in agen: |
| 198 | + logger.debug(_build_response_log(response)) |
| 199 | + async with Aclosing( |
| 200 | + aggregator.process_response(response) |
| 201 | + ) as aggregator_gen: |
| 202 | + async for llm_response in aggregator_gen: |
| 203 | + yield llm_response |
| 204 | + if (close_result := aggregator.close()) is not None: |
| 205 | + # Populate cache metadata in the final aggregated response for |
| 206 | + # streaming |
| 207 | + if cache_metadata: |
| 208 | + cache_manager.populate_cache_metadata_in_response( |
| 209 | + close_result, cache_metadata |
| 210 | + ) |
| 211 | + yield close_result |
| 212 | + |
| 213 | + else: |
| 214 | + response = await self.api_client.aio.models.generate_content( |
| 215 | + model=llm_request.model, |
| 216 | + contents=llm_request.contents, |
| 217 | + config=llm_request.config, |
| 218 | + ) |
| 219 | + logger.info('Response received from the model.') |
| 220 | + logger.debug(_build_response_log(response)) |
158 | 221 |
|
159 | | - # for sse, similar as bidi (see receive method in gemini_llm_connection.py), |
160 | | - # we need to mark those text content as partial and after all partial |
161 | | - # contents are sent, we send an accumulated event which contains all the |
162 | | - # previous partial content. The only difference is bidi rely on |
163 | | - # complete_turn flag to detect end while sse depends on finish_reason. |
164 | | - aggregator = StreamingResponseAggregator() |
165 | | - async with Aclosing(responses) as agen: |
166 | | - async for response in agen: |
167 | | - logger.debug(_build_response_log(response)) |
168 | | - async with Aclosing( |
169 | | - aggregator.process_response(response) |
170 | | - ) as aggregator_gen: |
171 | | - async for llm_response in aggregator_gen: |
172 | | - yield llm_response |
173 | | - if (close_result := aggregator.close()) is not None: |
174 | | - # Populate cache metadata in the final aggregated response for streaming |
| 222 | + llm_response = LlmResponse.create(response) |
175 | 223 | if cache_metadata: |
176 | 224 | cache_manager.populate_cache_metadata_in_response( |
177 | | - close_result, cache_metadata |
| 225 | + llm_response, cache_metadata |
178 | 226 | ) |
179 | | - yield close_result |
180 | | - |
181 | | - else: |
182 | | - response = await self.api_client.aio.models.generate_content( |
183 | | - model=llm_request.model, |
184 | | - contents=llm_request.contents, |
185 | | - config=llm_request.config, |
186 | | - ) |
187 | | - logger.info('Response received from the model.') |
188 | | - logger.debug(_build_response_log(response)) |
189 | | - |
190 | | - llm_response = LlmResponse.create(response) |
191 | | - if cache_metadata: |
192 | | - cache_manager.populate_cache_metadata_in_response( |
193 | | - llm_response, cache_metadata |
194 | | - ) |
195 | | - yield llm_response |
| 227 | + yield llm_response |
| 228 | + except ClientError as ce: |
| 229 | + if ce.code == 429: |
| 230 | +        # A Resource Exhausted error is a common client error for |
| 231 | +        # developers to run into, so we enhance the message with possible |
| 232 | +        # fixes for this issue. |
| 233 | + raise _ResourceExhaustedError(ce) from ce |
| 234 | + |
| 235 | + raise ce |
196 | 236 |
|
197 | 237 | @cached_property |
198 | 238 | def api_client(self) -> Client: |
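For context on how the new error handling surfaces to callers, here is a hypothetical sketch (the `gemini` instance, `llm_request`, and logging setup are assumptions, not part of this change). Because `_ResourceExhaustedError` subclasses `ClientError`, existing handlers keep working unchanged; the only difference is that a re-raised 429 now stringifies with the documentation link prepended.

```python
# Hypothetical caller: `gemini` and `llm_request` are assumed to be built
# elsewhere; this only shows how the re-raised error is observed.
import logging

from google.genai.errors import ClientError


async def run_once(gemini, llm_request):
  try:
    async for llm_response in gemini.generate_content_async(
        llm_request, stream=True
    ):
      ...  # handle partial and aggregated responses as before
  except ClientError as err:
    # A 429 arrives as _ResourceExhaustedError, which is still a ClientError,
    # so this handler needs no change; str(err) now includes the docs link.
    if err.code == 429:
      logging.error('Model quota exhausted: %s', err)
    raise
```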
|