|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import tempfile |
4 | 5 | from http import HTTPStatus |
5 | 6 | from io import StringIO |
6 | 7 | from typing import Awaitable, Callable, List, Optional |
@@ -51,6 +52,13 @@ def parse_args(): |
51 | 52 | help="The path or url to a single output file. Currently supports " |
52 | 53 | "local file paths, or web (http or https) urls. If a URL is specified," |
53 | 54 | " the file should be available via HTTP PUT.") |
| 55 | + parser.add_argument( |
| 56 | + "--output-tmp-dir", |
| 57 | + type=str, |
| 58 | + default=None, |
| 59 | + help="The directory to store the output file before uploading it " |
| 60 | + "to the output URL.", |
| 61 | + ) |
54 | 62 | parser.add_argument("--response-role", |
55 | 63 | type=nullable_str, |
56 | 64 | default="assistant", |
@@ -134,17 +142,107 @@ async def read_file(path_or_url: str) -> str: |
134 | 142 | return f.read() |
135 | 143 |
|
136 | 144 |
|
async def write_local_file(output_path: str,
                           batch_outputs: List[BatchRequestOutput]) -> None:
    """
    Serialize batch outputs to a local file, one JSON object per line.

    output_path: Destination file path (overwritten if it already exists).
    batch_outputs: The batch results to write out.
    """
    # Blocking file I/O is acceptable here: run_batch executes as a
    # standalone program, so stalling the event loop does not affect
    # overall performance.
    with open(output_path, "w", encoding="utf-8") as out:
        out.writelines(output.model_dump_json() + "\n"
                       for output in batch_outputs)
| 158 | + |
async def upload_data(output_url: str, data_or_file: str,
                      from_file: bool) -> None:
    """
    Upload data (or a local file's contents) to a URL via HTTP PUT.

    output_url: The URL to upload to.
    data_or_file: Either the data to upload or the path to the file to
        upload (see ``from_file``).
    from_file: If True, ``data_or_file`` is the path to the file to upload.

    Raises:
        Exception: If the upload still fails after all retry attempts.
    """
    # Timeout is a common issue when uploading large files.
    # We retry max_retries times before giving up.
    max_retries = 5
    # Number of seconds to wait before retrying.
    delay = 5

    for attempt in range(1, max_retries + 1):
        try:
            # We increase the timeout to 1000 seconds to allow
            # for large files (default is 300).
            timeout = aiohttp.ClientTimeout(total=1000)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                if from_file:
                    with open(data_or_file, "rb") as file:
                        async with session.put(output_url,
                                               data=file) as response:
                            if response.status != 200:
                                # response.text() is a coroutine and must
                                # be awaited before formatting.
                                body = await response.text()
                                raise Exception(
                                    f"Failed to upload file.\n"
                                    f"Status: {response.status}\n"
                                    f"Response: {body}")
                else:
                    async with session.put(output_url,
                                           data=data_or_file) as response:
                        if response.status != 200:
                            body = await response.text()
                            raise Exception(
                                f"Failed to upload data.\n"
                                f"Status: {response.status}\n"
                                f"Response: {body}")
            # Upload succeeded: stop retrying. Without this the loop
            # would re-upload the same data on every remaining attempt.
            return
        except Exception as e:
            if attempt < max_retries:
                # Lazy %-style args keep formatting out of the hot path
                # and match the logging style used elsewhere in this file.
                logger.error(
                    "Failed to upload data (attempt %d). "
                    "Error message: %s.\nRetrying in %d seconds...", attempt,
                    e, delay)
                await asyncio.sleep(delay)
            else:
                raise Exception(f"Failed to upload data (attempt {attempt}). "
                                f"Error message: {str(e)}.") from e
| 206 | + |
async def write_file(path_or_url: str, batch_outputs: List[BatchRequestOutput],
                     output_tmp_dir: Optional[str]) -> None:
    """
    Write batch_outputs to a local file or upload them to an HTTP(S) URL.

    path_or_url: Local file path, or an http(s) URL accepting HTTP PUT.
    batch_outputs: The list of batch outputs to write.
    output_tmp_dir: When uploading to a URL, the directory used to stage
        the output file before the upload. If None, the serialized output
        is staged in an in-memory buffer instead. (Annotated Optional[str]
        to match the argparse default of None.)
    """
    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
        if output_tmp_dir is None:
            logger.info("Writing outputs to memory buffer")
            output_buffer = StringIO()
            for o in batch_outputs:
                print(o.model_dump_json(), file=output_buffer)
            output_buffer.seek(0)
            logger.info("Uploading outputs to %s", path_or_url)
            await upload_data(
                path_or_url,
                output_buffer.read().strip().encode("utf-8"),
                from_file=False,
            )
        else:
            # Write responses to a temporary file and then upload it to
            # the URL. write_local_file reopens f.name with its own
            # handle; this NamedTemporaryFile only reserves the path and
            # cleans it up on exit (POSIX semantics — reopening an open
            # NamedTemporaryFile is not portable to Windows).
            with tempfile.NamedTemporaryFile(
                    mode="w",
                    encoding="utf-8",
                    dir=output_tmp_dir,
                    prefix="tmp_batch_output_",
                    suffix=".jsonl",
            ) as f:
                logger.info("Writing outputs to temporary local file %s",
                            f.name)
                await write_local_file(f.name, batch_outputs)
                logger.info("Uploading outputs to %s", path_or_url)
                await upload_data(path_or_url, f.name, from_file=True)
    else:
        logger.info("Writing outputs to local file %s", path_or_url)
        await write_local_file(path_or_url, batch_outputs)
148 | 246 |
|
149 | 247 |
|
150 | 248 | def make_error_request_output(request: BatchRequestInput, |
@@ -317,12 +415,7 @@ async def main(args): |
317 | 415 | with tracker.pbar(): |
318 | 416 | responses = await asyncio.gather(*response_futures) |
319 | 417 |
|
320 | | - output_buffer = StringIO() |
321 | | - for response in responses: |
322 | | - print(response.model_dump_json(), file=output_buffer) |
323 | | - |
324 | | - output_buffer.seek(0) |
325 | | - await write_file(args.output_file, output_buffer.read().strip()) |
| 418 | + await write_file(args.output_file, responses, args.output_tmp_dir) |
326 | 419 |
|
327 | 420 |
|
328 | 421 | if __name__ == "__main__": |
|
0 commit comments