# main.py
import os
import gc
import sys
import json
import time
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict

import torch
import whisper
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
# Set up console logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)
# Initialize FastAPI with metadata
app = FastAPI(
    title="Whisper Transcription API",
    description="API for transcribing audio files using OpenAI's Whisper model",
    version="1.0.0"
)

# Add CORS middleware.
# Note: allow_origins=["*"] combined with allow_credentials=True is very
# permissive; restrict the origins list for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables
model = None
MAX_CONCURRENT_TRANSCRIPTIONS = 3  # Adjust based on system resources
transcription_executor = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_TRANSCRIPTIONS)
active_jobs: Dict[str, dict] = {}  # Job status information, keyed by job ID
job_threads: Dict[str, threading.Thread] = {}  # Worker threads, for termination tracking
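
# Job lifecycle (statuses used throughout this module):
#   uploading -> queued -> processing -> completed | failed | terminated
# Termination is cooperative: DELETE /jobs/{job_id} sets a "terminated" flag
# on the job entry, which the worker checks before starting the (blocking)
# transcription call.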
@app.on_event("startup")
async def startup_event():
global model
logger.info("Loading Whisper model...")
model = whisper.load_model("base")
logger.info("Model loaded successfully!")
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down transcription executor...")
transcription_executor.shutdown(wait=False)
# Terminate any running jobs
for job_id in list(job_threads.keys()):
await terminate_job(job_id)
def cleanup_file(file_path: str):
    """Clean up a temporary file and force garbage collection."""
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Cleaned up temporary file: {file_path}")
        gc.collect()  # Force garbage collection
    except Exception as e:
        logger.error(f"Error cleaning up file {file_path}: {str(e)}")
def process_large_file(file_path: str, job_id: str) -> dict:
    """Process a large audio file with memory optimization."""
    try:
        # Register the worker thread so the job can be tracked for termination
        job_threads[job_id] = threading.current_thread()

        if not os.path.exists(file_path):
            raise Exception(f"Audio file not found: {file_path}")

        # Honor a termination request that arrived while the job was queued
        if active_jobs[job_id].get("terminated"):
            logger.info(f"Job {job_id} was terminated before processing started")
            return {}

        logger.info(f"Starting transcription for job {job_id}")
        active_jobs[job_id]["status"] = "processing"

        # Force garbage collection before processing
        gc.collect()

        # Process the file
        result = model.transcribe(
            file_path,
            verbose=True,
            fp16=False,  # fp16 needs a GPU; fp32 is the safe choice on CPU
            task='transcribe'
        )

        logger.info(f"Transcription completed successfully for job {job_id}")
        active_jobs[job_id]["status"] = "completed"
        active_jobs[job_id]["result"] = {
            "text": result["text"],
            "segments": result.get("segments", [])
        }
        logger.info(f"Transcription result for job {job_id}: {result['text'][:200]}...")
        return result
    except Exception as e:
        error_msg = f"Error during transcription: {str(e)}"
        logger.error(f"Job {job_id}: {error_msg}")
        active_jobs[job_id]["status"] = "failed"
        active_jobs[job_id]["error"] = error_msg
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise
    finally:
        job_threads.pop(job_id, None)
        cleanup_file(file_path)  # Remove the temp file whether the job succeeded or not
        gc.collect()
@app.get("/", response_class=HTMLResponse)
async def root():
return """
<html>
<head>
<title>Whisper Transcription Service</title>
<style>
body {
font-family: system-ui, -apple-system, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 2rem;
line-height: 1.6;
}
h1 { color: #2563eb; }
.endpoint {
background: #f8fafc;
padding: 1rem;
border-radius: 0.5rem;
margin: 1rem 0;
}
</style>
</head>
<body>
<h1>📝 Whisper Transcription Service</h1>
<p>Optimized for handling large audio files with concurrent processing.</p>
<div class="endpoint">
<h3>Transcribe Audio</h3>
<code>POST /transcribe/</code>
<p>Submit any size audio file for transcription.</p>
</div>
<div class="endpoint">
<h3>Check Job Status</h3>
<code>GET /status/{job_id}</code>
<p>Check the status of a transcription job.</p>
</div>
<div class="endpoint">
<h3>List Active Jobs</h3>
<code>GET /jobs</code>
<p>List all active transcription jobs.</p>
</div>
<div class="endpoint">
<h3>Terminate Job</h3>
<code>DELETE /jobs/{job_id}</code>
<p>Terminate a running transcription job.</p>
</div>
<div class="endpoint">
<h3>Health Check</h3>
<code>GET /health</code>
<p>Check the service status and supported formats.</p>
</div>
</body>
</html>
"""
@app.get("/health")
async def health_check():
"""Check the health status of the service"""
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return JSONResponse(content={
"status": "healthy",
"model": "whisper-base",
"supported_formats": [".mp3", ".wav", ".m4a", ".ogg", ".flac"],
"max_file_size": "unlimited",
"gpu_available": torch.cuda.is_available(),
"active_jobs": len(active_jobs),
"max_concurrent_jobs": MAX_CONCURRENT_TRANSCRIPTIONS
})
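
# Example /health response (illustrative values only):
# {
#   "status": "healthy",
#   "model": "whisper-base",
#   "supported_formats": [".mp3", ".wav", ".m4a", ".ogg", ".flac"],
#   "max_file_size": "unlimited",
#   "gpu_available": false,
#   "active_jobs": 0,
#   "max_concurrent_jobs": 3
# }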
@app.get("/jobs")
async def list_jobs():
"""List all active transcription jobs"""
jobs_list = []
for job_id, info in active_jobs.items():
job_info = {
"job_id": job_id,
"status": info["status"],
"created_at": info["created_at"],
"filename": info["filename"]
}
if info["status"] == "failed" and "error" in info:
job_info["message"] = info["error"]
jobs_list.append(job_info)
return JSONResponse(content={"jobs": jobs_list})
@app.delete("/jobs/{job_id}")
async def terminate_job(job_id: str):
"""Terminate a running transcription job"""
if job_id not in active_jobs:
raise HTTPException(status_code=404, detail="Job not found")
if active_jobs[job_id]["status"] not in ["processing", "queued"]:
raise HTTPException(status_code=400, detail="Job is not running or queued")
# Mark job as terminated
active_jobs[job_id]["status"] = "terminated"
# Mark job as terminated and set flag
active_jobs[job_id]["terminated"] = True
active_jobs[job_id]["status"] = "terminated"
# Remove thread reference if it exists
if job_id in job_threads:
del job_threads[job_id]
return JSONResponse(content={"message": f"Job {job_id} terminated successfully"})
@app.get("/status/{job_id}")
async def get_job_status(job_id: str):
"""Get the status of a transcription job"""
if job_id not in active_jobs:
raise HTTPException(status_code=404, detail="Job not found")
job_info = active_jobs[job_id]
response = {
"job_id": job_id,
"status": job_info["status"],
"created_at": job_info["created_at"],
"filename": job_info["filename"]
}
if job_info["status"] == "failed":
response["message"] = job_info.get("error", "Unknown error occurred")
elif job_info["status"] == "completed":
response["result"] = job_info.get("result", {})
return JSONResponse(content=response)
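
# Example /status/{job_id} response for a completed job (illustrative values;
# the job ID below is made up, in the format produced by /transcribe/):
# {
#   "job_id": "job_1700000000_a1b2c3d4",
#   "status": "completed",
#   "created_at": 1700000000.0,
#   "filename": "sample.mp3",
#   "result": {"text": "...", "segments": [...]}
# }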
@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
"""
Transcribe an audio file using the Whisper model.
Optimized for large files with no size limit and concurrent processing.
"""
logger.info(f"Received file: {file.filename} for transcription")
# Validate file extension
file_extension = Path(file.filename).suffix.lower()
valid_extensions = {'.mp3', '.wav', '.m4a', '.ogg', '.flac'}
if file_extension not in valid_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {', '.join(valid_extensions)}"
)
# Create a unique job ID
job_id = f"job_{int(time.time())}_{os.urandom(4).hex()}"
logger.info(f"Created job ID: {job_id}")
temp_file_path = None
try:
# Create a temporary file in our dedicated temp directory
temp_file_path = os.path.join('/app/temp', f'whisper_{job_id}{file_extension}')
with open(temp_file_path, 'wb') as temp_file:
logger.info(f"Created temporary file: {temp_file_path}")
# Initialize job status
active_jobs[job_id] = {
"status": "uploading",
"created_at": time.time(),
"filename": file.filename,
"terminated": False,
"temp_file": temp_file_path
}
# Save uploaded file using chunked transfer
chunk_size = 1024 * 1024 # 1MB chunks
total_size = 0
while chunk := await file.read(chunk_size):
temp_file.write(chunk)
total_size += len(chunk)
file_size = os.path.getsize(temp_file_path)
logger.info(f"File saved successfully. Size: {file_size:,} bytes")
# Submit the job to the thread pool
active_jobs[job_id]["status"] = "queued"
future = transcription_executor.submit(process_large_file, temp_file_path, job_id)
# Store the thread reference
job_threads[job_id] = threading.current_thread()
# Return job ID for status checking
response = {
"job_id": job_id,
"status": "queued",
"message": "Transcription job queued successfully",
"file_info": {
"name": file.filename,
"size": file_size
}
}
logger.info(f"Job created successfully: {json.dumps(response)}")
return JSONResponse(content=response)
except Exception as e:
logger.error(f"Error during transcription: {str(e)}")
# Update job status
if job_id in active_jobs:
active_jobs[job_id]["status"] = "failed"
active_jobs[job_id]["error"] = str(e)
# Clean up in case of error
if temp_file_path and os.path.exists(temp_file_path):
cleanup_file(temp_file_path)
raise HTTPException(
status_code=500,
detail=f"Transcription failed: {str(e)}"
)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")
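
# Example client (a minimal sketch; assumes the service runs at
# http://localhost:8000, that a local "sample.mp3" exists, and that the
# third-party `requests` package is installed):
#
#   import time
#   import requests
#
#   with open("sample.mp3", "rb") as f:
#       resp = requests.post("http://localhost:8000/transcribe/",
#                            files={"file": ("sample.mp3", f, "audio/mpeg")})
#   job_id = resp.json()["job_id"]
#
#   # Poll until the job reaches a terminal state
#   while True:
#       status = requests.get(f"http://localhost:8000/status/{job_id}").json()
#       if status["status"] in ("completed", "failed", "terminated"):
#           break
#       time.sleep(2)
#
#   print(status.get("result", {}).get("text", ""))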