|  | 
| 9 | 9 | 
 | 
| 10 | 10 | import httpx | 
| 11 | 11 | 
 | 
| 12 |  | -from ._utils import extract_type_var_from_base | 
|  | 12 | +from ._utils import is_mapping, extract_type_var_from_base | 
|  | 13 | +from ._exceptions import APIError | 
| 13 | 14 | 
 | 
| 14 | 15 | if TYPE_CHECKING: | 
| 15 | 16 |     from ._client import GradientAI, AsyncGradientAI | 
| @@ -55,7 +56,25 @@ def __stream__(self) -> Iterator[_T]: | 
| 55 | 56 |         iterator = self._iter_events() | 
| 56 | 57 | 
 | 
| 57 | 58 |         for sse in iterator: | 
| 58 |  | -            yield process_data(data=sse.json(), cast_to=cast_to, response=response) | 
|  | 59 | +            if sse.data.startswith("[DONE]"): | 
|  | 60 | +                break | 
|  | 61 | + | 
|  | 62 | +            data = sse.json() | 
|  | 63 | +            if is_mapping(data) and data.get("error"): | 
|  | 64 | +                message = None | 
|  | 65 | +                error = data.get("error") | 
|  | 66 | +                if is_mapping(error): | 
|  | 67 | +                    message = error.get("message") | 
|  | 68 | +                if not message or not isinstance(message, str): | 
|  | 69 | +                    message = "An error occurred during streaming" | 
|  | 70 | + | 
|  | 71 | +                raise APIError( | 
|  | 72 | +                    message=message, | 
|  | 73 | +                    request=self.response.request, | 
|  | 74 | +                    body=data["error"], | 
|  | 75 | +                ) | 
|  | 76 | + | 
|  | 77 | +            yield process_data(data=data, cast_to=cast_to, response=response) | 
| 59 | 78 | 
 | 
| 60 | 79 |         # Ensure the entire stream is consumed | 
| 61 | 80 |         for _sse in iterator: | 
| @@ -119,7 +138,25 @@ async def __stream__(self) -> AsyncIterator[_T]: | 
| 119 | 138 |         iterator = self._iter_events() | 
| 120 | 139 | 
 | 
| 121 | 140 |         async for sse in iterator: | 
| 122 |  | -            yield process_data(data=sse.json(), cast_to=cast_to, response=response) | 
|  | 141 | +            if sse.data.startswith("[DONE]"): | 
|  | 142 | +                break | 
|  | 143 | + | 
|  | 144 | +            data = sse.json() | 
|  | 145 | +            if is_mapping(data) and data.get("error"): | 
|  | 146 | +                message = None | 
|  | 147 | +                error = data.get("error") | 
|  | 148 | +                if is_mapping(error): | 
|  | 149 | +                    message = error.get("message") | 
|  | 150 | +                if not message or not isinstance(message, str): | 
|  | 151 | +                    message = "An error occurred during streaming" | 
|  | 152 | + | 
|  | 153 | +                raise APIError( | 
|  | 154 | +                    message=message, | 
|  | 155 | +                    request=self.response.request, | 
|  | 156 | +                    body=data["error"], | 
|  | 157 | +                ) | 
|  | 158 | + | 
|  | 159 | +            yield process_data(data=data, cast_to=cast_to, response=response) | 
| 123 | 160 | 
 | 
| 124 | 161 |         # Ensure the entire stream is consumed | 
| 125 | 162 |         async for _sse in iterator: | 
|  | 
0 commit comments