diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py index fa0a30e183..6c78154051 100644 --- a/src/openai/_streaming.py +++ b/src/openai/_streaming.py @@ -59,12 +59,7 @@ def __stream__(self) -> Iterator[_T]: if sse.data.startswith("[DONE]"): break - if sse.event is None or ( - sse.event.startswith("response.") or - sse.event.startswith("transcript.") or - sse.event.startswith("image_edit.") or - sse.event.startswith("image_generation.") - ): + if is_valid_event(sse.event): data = sse.json() if is_mapping(data) and data.get("error"): message = None @@ -166,7 +161,7 @@ async def __stream__(self) -> AsyncIterator[_T]: if sse.data.startswith("[DONE]"): break - if sse.event is None or sse.event.startswith("response.") or sse.event.startswith("transcript."): + if is_valid_event(sse.event): data = sse.json() if is_mapping(data) and data.get("error"): message = None @@ -390,6 +385,16 @@ def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[Asy return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream)) +def is_valid_event(event: str | None) -> bool: + """Given an event fieldname, checks if it is a response, transcript, or None""" + VALID_EVENTS = ("response", "transcript", "image_edit", "image_generation") + if event is None: + return True + if event in VALID_EVENTS or any(event.startswith(f"{e}.") for e in VALID_EVENTS): + return True + return False + + def extract_stream_chunk_type( stream_cls: type, *,