diff --git a/.github/workflows/ai-runner-docker.yaml b/.github/workflows/ai-runner-docker.yaml
index 1bd992a1..b18c6ec3 100644
--- a/.github/workflows/ai-runner-docker.yaml
+++ b/.github/workflows/ai-runner-docker.yaml
@@ -76,4 +76,4 @@ jobs:
           file: "Dockerfile"
           labels: ${{ steps.meta.outputs.labels }}
           cache-from: type=registry,ref=livepeerci/build:cache
-          cache-to: type=registry,ref=livepeerci/build:cache,mode=max
+          cache-to: type=registry,ref=livepeerci/build:cache,mode=max
\ No newline at end of file
diff --git a/.github/workflows/validate-openapi-on-pr.yaml b/.github/workflows/validate-openapi-on-pr.yaml
index 68a495d3..db00ae79 100644
--- a/.github/workflows/validate-openapi-on-pr.yaml
+++ b/.github/workflows/validate-openapi-on-pr.yaml
@@ -46,4 +46,4 @@ jobs:
           if ! git diff --exit-code; then
             echo "::error::Go bindings have changed. Please run 'make' at the root of the repository and commit the changes."
             exit 1
-          fi
+          fi
\ No newline at end of file
diff --git a/go.mod b/go.mod
index 226ae8fc..f17ce7d0 100644
--- a/go.mod
+++ b/go.mod
@@ -5,7 +5,7 @@ go 1.21.5
 require (
 	github.com/deepmap/oapi-codegen/v2 v2.2.0
 	github.com/docker/cli v24.0.5+incompatible
-	github.com/docker/docker v24.0.7+incompatible
+	github.com/docker/docker v24.0.9+incompatible
 	github.com/docker/go-connections v0.4.0
 	github.com/getkin/kin-openapi v0.124.0
 	github.com/go-chi/chi/v5 v5.0.12
diff --git a/go.sum b/go.sum
index 015a9ff7..eacbbeb1 100644
--- a/go.sum
+++ b/go.sum
@@ -17,8 +17,8 @@ github.com/docker/cli v24.0.5+incompatible h1:WeBimjvS0eKdH4Ygx+ihVq1Q++xg36M/rM
 github.com/docker/cli v24.0.5+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
 github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk=
 github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
-github.com/docker/docker v24.0.7+incompatible h1:Wo6l37AuwP3JaMnZa226lzVXGA3F9Ig1seQen0cKYlM=
-github.com/docker/docker v24.0.7+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
+github.com/docker/docker v24.0.9+incompatible h1:HPGzNmwfLZWdxHqK9/II92pyi1EpYKsAqcl4G0Of9v0=
+github.com/docker/docker v24.0.9+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
 github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ=
 github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
 github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
diff --git a/runner/app/main.py b/runner/app/main.py
index 147a0c21..a2596283 100644
--- a/runner/app/main.py
+++ b/runner/app/main.py
@@ -1,5 +1,7 @@
 import logging
 import os
+import sys
+import cv2
 from contextlib import asynccontextmanager
 
 from app.routes import health
@@ -15,8 +17,8 @@ async def lifespan(app: FastAPI):
     app.include_router(health.router)
 
-    pipeline = os.environ["PIPELINE"]
-    model_id = os.environ["MODEL_ID"]
+    pipeline = os.environ.get("PIPELINE", "")  # Default to an empty string if unset.
+    model_id = os.environ.get("MODEL_ID", "")  # Provide a default if necessary.
 
     app.pipeline = load_pipeline(pipeline, model_id)
     app.include_router(load_route(pipeline))
@@ -44,8 +46,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
             from app.pipelines.audio_to_text import AudioToTextPipeline
 
             return AudioToTextPipeline(model_id)
-        case "frame-interpolation":
-            raise NotImplementedError("frame-interpolation pipeline not implemented")
+        case "FILMPipeline":
+            from app.pipelines.frame_interpolation import FILMPipeline
+
+            return FILMPipeline(model_id)
         case "upscale":
             from app.pipelines.upscale import UpscalePipeline
@@ -78,8 +82,10 @@ def load_route(pipeline: str) -> any:
             from app.routes import audio_to_text
 
             return audio_to_text.router
-        case "frame-interpolation":
-            raise NotImplementedError("frame-interpolation pipeline not implemented")
+        case "FILMPipeline":
+            from app.routes import frame_interpolation
+
+            return frame_interpolation.router
         case "upscale":
             from app.routes import upscale
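Note: a minimal smoke-test sketch for the new load_pipeline/load_route arms above; the checkpoint name "film_net_fp16.pt" is an assumption for illustration, not part of this diff.

# Hypothetical: select the new FILM pipeline the same way main.py's lifespan does.
import os

os.environ["PIPELINE"] = "FILMPipeline"       # Matches the new match-case arm.
os.environ["MODEL_ID"] = "film_net_fp16.pt"   # Assumed TorchScript checkpoint name.

from app.main import load_pipeline, load_route

pipeline = load_pipeline(os.environ["PIPELINE"], os.environ["MODEL_ID"])
router = load_route(os.environ["PIPELINE"])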
diff --git a/runner/app/pipelines/frame_interpolation.py b/runner/app/pipelines/frame_interpolation.py
index 396c8067..755afee7 100644
--- a/runner/app/pipelines/frame_interpolation.py
+++ b/runner/app/pipelines/frame_interpolation.py
@@ -1,5 +1,113 @@
-from app.pipelines.base import Pipeline
+import torch
+from torchvision.transforms import v2
+from tqdm import tqdm
+import bisect
+import numpy as np
+from app.pipelines.utils.utils import get_model_dir
 
-class FrameInterpolationPipeline(Pipeline):
-    pass
+class FILMPipeline:
+    model: torch.jit.ScriptModule
+
+    def __init__(self, model_id: str):
+        self.model_id = model_id
+        model_dir = get_model_dir()  # Get the directory where models are stored.
+        model_path = f"{model_dir}/{model_id}"  # Construct the full path to the model file.
+
+        self.model = torch.jit.load(model_path, map_location="cpu")
+        self.model.eval()
+
+    def to(self, *args, **kwargs):
+        self.model = self.model.to(*args, **kwargs)
+        return self
+
+    @property
+    def device(self) -> torch.device:
+        # Checking the device of a ScriptModule requires inspecting one of its parameters.
+        params = self.model.parameters()
+        return next(params).device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        # Checking the dtype of a ScriptModule requires inspecting one of its parameters.
+        params = self.model.parameters()
+        return next(params).dtype
+
+    def __call__(
+        self,
+        reader,
+        writer,
+        inter_frames: int = 2,
+    ):
+        transforms = v2.Compose(
+            [
+                v2.ToDtype(torch.uint8, scale=True),
+            ]
+        )
+
+        writer.open()
+
+        while True:
+            frame_1 = reader.get_frame()
+            # If the first frame read is None, then there are no more frames.
+            if frame_1 is None:
+                break
+
+            frame_2 = reader.get_frame()
+            # If the second frame read is None, then frame_1 is the final frame.
+            if frame_2 is None:
+                writer.write_frame(transforms(frame_1))
+                break
+
+            # frame_1 and frame_2 must be tensors in n c h w format.
+            frame_1 = frame_1.unsqueeze(0)
+            frame_2 = frame_2.unsqueeze(0)
+
+            frames = inference(
+                self.model, frame_1, frame_2, inter_frames, self.device, self.dtype
+            )
+
+            frames = [transforms(frame.detach().cpu()) for frame in frames]
+            for frame in frames:
+                writer.write_frame(frame)
+
+        writer.close()
+
+
+def inference(
+    model, img_batch_1, img_batch_2, inter_frames, device, dtype
+) -> list:
+    results = [img_batch_1, img_batch_2]
+
+    idxes = [0, inter_frames + 1]
+    remains = list(range(1, inter_frames + 1))
+
+    splits = torch.linspace(0, 1, inter_frames + 2)
+
+    for _ in tqdm(range(len(remains)), "Generating in-between frames"):
+        starts = splits[idxes[:-1]]
+        ends = splits[idxes[1:]]
+        distances = (
+            (splits[None, remains] - starts[:, None])
+            / (ends[:, None] - starts[:, None])
+            - 0.5
+        ).abs()
+        matrix = torch.argmin(distances).item()
+        start_i, step = np.unravel_index(matrix, distances.shape)
+        end_i = start_i + 1
+
+        x0 = results[start_i].to(device=device, dtype=dtype)
+        x1 = results[end_i].to(device=device, dtype=dtype)
+
+        dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (
+            splits[idxes[end_i]] - splits[idxes[start_i]]
+        )
+
+        with torch.no_grad():
+            prediction = model(x0, x1, dt)
+        insert_position = bisect.bisect_left(idxes, remains[step])
+        idxes.insert(insert_position, remains[step])
+        results.insert(insert_position, prediction.clamp(0, 1).float())
+        del remains[step]
+
+    return results
\ No newline at end of file
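Note: a minimal usage sketch for the FILMPipeline above; the directory names, the checkpoint name, and the CUDA device are assumptions. DirectoryReader and DirectoryWriter come from the utils changes later in this diff.

# Hypothetical driver: insert 2 generated frames between each consecutive input pair.
import torch

from app.pipelines.frame_interpolation import FILMPipeline
from app.pipelines.utils.utils import DirectoryReader, DirectoryWriter

reader = DirectoryReader("frames_in")    # Expects frames named 0.png, 1.png, ... (natural sort).
writer = DirectoryWriter("frames_out")   # Receives originals plus in-between frames.

pipeline = FILMPipeline("film_net_fp16.pt")  # Resolved relative to get_model_dir().
pipeline = pipeline.to(device="cuda", dtype=torch.float16)  # Assumes a CUDA device is available.
pipeline(reader, writer, inter_frames=2)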
diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py
index 00f34f1d..45eb0c65 100644
--- a/runner/app/pipelines/image_to_image.py
+++ b/runner/app/pipelines/image_to_image.py
@@ -1,6 +1,7 @@
 import logging
 import os
 from enum import Enum
+import time
 from typing import List, Optional, Tuple
 
 import PIL
@@ -29,6 +30,7 @@
 logger = logging.getLogger(__name__)
 
+SFAST_WARMUP_ITERATIONS = 2  # Model warm-up iterations when SFAST is enabled.
 
 class ModelName(Enum):
     """Enumeration mapping model names to their corresponding IDs."""
@@ -146,11 +148,29 @@ def __init__(self, model_id: str):
         # Warm-up the pipeline.
         # TODO: Not yet supported for ImageToImagePipeline.
         if os.getenv("SFAST_WARMUP", "true").lower() == "true":
-            logger.warning(
-                "The 'SFAST_WARMUP' flag is not yet supported for the "
-                "ImageToImagePipeline and will be ignored. As a result the first "
-                "call may be slow if 'SFAST' is enabled."
-            )
+            warmup_kwargs = {
+                "prompt": "A warmed up pipeline is a happy pipeline a short poem by ricksta",
+                "image": PIL.Image.new("RGB", (576, 1024)),
+                "strength": 0.8,
+                "negative_prompt": "No blurry or weird artifacts",
+                "num_images_per_prompt": 4,
+            }
+
+            logger.info("Warming up ImageToImagePipeline pipeline...")
+            total_time = 0
+            for ii in range(SFAST_WARMUP_ITERATIONS):
+                t = time.time()
+                try:
+                    self.ldm(**warmup_kwargs).images
+                except Exception as e:
+                    logger.error(f"ImageToImagePipeline warmup error: {e}")
+                    raise e
+                iteration_time = time.time() - t
+                total_time += iteration_time
+                logger.info(
+                    "Warmup iteration %s took %s seconds", ii + 1, iteration_time
+                )
+            logger.info("Total warmup time: %s seconds", total_time)
 
         if deepcache_enabled and not (
             is_lightning_model(model_id) or is_turbo_model(model_id)
diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py
index 7b248dcc..46a074ea 100644
--- a/runner/app/pipelines/image_to_video.py
+++ b/runner/app/pipelines/image_to_video.py
@@ -6,10 +6,14 @@
 import PIL
 import torch
 from app.pipelines.base import Pipeline
-from app.pipelines.utils import SafetyChecker, get_model_dir, get_torch_device
 from diffusers import StableVideoDiffusionPipeline
 from huggingface_hub import file_download
 from PIL import ImageFile
+from app.pipelines.utils import (
+    SafetyChecker,
+    get_model_dir,
+    get_torch_device
+)
 
 ImageFile.LOAD_TRUNCATED_IMAGES = True
diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py
index b24a3a95..03d89f6f 100644
--- a/runner/app/pipelines/text_to_image.py
+++ b/runner/app/pipelines/text_to_image.py
@@ -1,6 +1,7 @@
 import logging
 import os
 from enum import Enum
+import time
 from typing import List, Optional, Tuple
 
 import PIL
@@ -28,6 +29,7 @@
 logger = logging.getLogger(__name__)
 
+SFAST_WARMUP_ITERATIONS = 2  # Model warm-up iterations when SFAST is enabled.
 
 class ModelName(Enum):
     """Enumeration mapping model names to their corresponding IDs."""
@@ -173,14 +175,31 @@ def __init__(self, model_id: str):
         self.ldm = compile_model(self.ldm)
 
-        # Warm-up the pipeline.
-        # TODO: Not yet supported for TextToImagePipeline.
         if os.getenv("SFAST_WARMUP", "true").lower() == "true":
-            logger.warning(
-                "The 'SFAST_WARMUP' flag is not yet supported for the "
-                "TextToImagePipeline and will be ignored. As a result the first "
-                "call may be slow if 'SFAST' is enabled."
-            )
+            # Retrieve default model params.
+            # TODO: Retrieve defaults from Pydantic class in route.
+            warmup_kwargs = {
+                "prompt": "A happy pipe in the line looking at the wall with words sfast",
+                "num_images_per_prompt": 4,
+                "negative_prompt": "No blurry or weird artifacts",
+            }
+
+            logger.info("Warming up TextToImagePipeline pipeline...")
+            total_time = 0
+            for ii in range(SFAST_WARMUP_ITERATIONS):
+                t = time.time()
+                try:
+                    self.ldm(**warmup_kwargs).images
+                except Exception as e:
+                    # FIXME: When out of memory, pipeline is corrupted.
+                    logger.error(f"TextToImagePipeline warmup error: {e}")
+                    raise e
+                iteration_time = time.time() - t
+                total_time += iteration_time
+                logger.info(
+                    "Warmup iteration %s took %s seconds", ii + 1, iteration_time
+                )
+            logger.info("Total warmup time: %s seconds", total_time)
 
         if deepcache_enabled and not (
             is_lightning_model(model_id) or is_turbo_model(model_id)
diff --git a/runner/app/pipelines/upscale.py b/runner/app/pipelines/upscale.py
index 5122a35d..b4978262 100644
--- a/runner/app/pipelines/upscale.py
+++ b/runner/app/pipelines/upscale.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import time
 from typing import List, Optional, Tuple
 
 import PIL
@@ -20,6 +21,7 @@
 logger = logging.getLogger(__name__)
 
+SFAST_WARMUP_ITERATIONS = 2  # Model warm-up iterations when SFAST is enabled.
 
 class UpscalePipeline(Pipeline):
     def __init__(self, model_id: str):
@@ -67,11 +69,29 @@ def __init__(self, model_id: str):
         # Warm-up the pipeline.
         # TODO: Not yet supported for UpscalePipeline.
         if os.getenv("SFAST_WARMUP", "true").lower() == "true":
-            logger.warning(
-                "The 'SFAST_WARMUP' flag is not yet supported for the "
-                "UpscalePipeline and will be ignored. As a result the first "
-                "call may be slow if 'SFAST' is enabled."
-            )
+            # Retrieve default model params.
+            # TODO: Retrieve defaults from Pydantic class in route.
+            warmup_kwargs = {
+                "prompt": "Upscaling the pipeline with sfast enabled",
+                "image": PIL.Image.new("RGB", (400, 400)),  # Anything larger than this size causes the model to OOM.
+            }
+
+            logger.info("Warming up Upscale pipeline...")
+            total_time = 0
+            for ii in range(SFAST_WARMUP_ITERATIONS):
+                t = time.time()
+                try:
+                    self.ldm(**warmup_kwargs).images
+                except Exception as e:
+                    # FIXME: When out of memory, pipeline is corrupted.
+                    logger.error(f"Upscale pipeline warmup error: {e}")
+                    raise e
+                iteration_time = time.time() - t
+                total_time += iteration_time
+                logger.info(
+                    "Warmup iteration %s took %s seconds", ii + 1, iteration_time
+                )
+            logger.info("Total warmup time: %s seconds", total_time)
 
         if deepcache_enabled and not (
             is_lightning_model(model_id) or is_turbo_model(model_id)
@@ -86,7 +106,7 @@ def __init__(self, model_id: str):
         elif deepcache_enabled:
             logger.warning(
                 "DeepCache is not supported for Lightning or Turbo models. "
" - "TextToImagePipeline will NOT be optimized with DeepCache for %s", + "UpscalingPiepline will NOT be optimized with DeepCache for %s", model_id, ) diff --git a/runner/app/pipelines/utils/__init__.py b/runner/app/pipelines/utils/__init__.py index c992e2c4..d6543349 100644 --- a/runner/app/pipelines/utils/__init__.py +++ b/runner/app/pipelines/utils/__init__.py @@ -11,4 +11,8 @@ is_turbo_model, split_prompt, validate_torch_device, + frames_compactor, + video_shredder, + DirectoryReader, + DirectoryWriter ) diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py index 0f578896..fb5fdf11 100644 --- a/runner/app/pipelines/utils/utils.py +++ b/runner/app/pipelines/utils/utils.py @@ -4,12 +4,19 @@ import os import re from pathlib import Path -from typing import Dict, Optional +from typing import Optional +import glob +import tempfile +from io import BytesIO +from typing import List, Union, Dict, Optional import numpy as np import torch from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from PIL import Image +from torchvision.transforms import v2 +import cv2 +from torchaudio.io import StreamWriter from torch import dtype as TorchDtype from transformers import CLIPImageProcessor @@ -119,6 +126,108 @@ def split_prompt( return prompt_dict +def frames_compactor( + frames: Union[List[np.ndarray], List[torch.Tensor]], + output_path: str, + fps: float, + codec: str = "MJPEG", + is_directory: bool = False, + width: int = None, + height: int = None +) -> None: + """ + Generate a video from a list of frames. Frames can be from a directory or in-memory. + + Args: + frames (List[np.ndarray] | List[torch.Tensor]): List of frames as NumPy arrays or PyTorch tensors. + output_path (str): Path to save the output video file. + fps (float): Frames per second for the video. + codec (str): Codec used for video compression (default is "XVID"). + is_directory (bool): If True, treat `frames` as a directory path containing image files. + width (int): Width of the video. Must be provided if `frames` are in-memory. + height (int): Height of the video. Must be provided if `frames` are in-memory. + + Returns: + None + """ + if is_directory: + # Read frames from a directory + frames = [cv2.imread(os.path.join(frames, file)) for file in sorted(os.listdir(frames))] + else: + # Convert torch tensors to numpy arrays if necessary + if isinstance(frames[0], torch.Tensor): + frames = [frame.permute(1, 2, 0).cpu().numpy() for frame in frames] + + # Ensure frames are numpy arrays and are uint8 type + frames = [frame.astype(np.uint8) for frame in frames] + + # Check if frames are consistent + if not frames: + raise ValueError("No frames to process.") + + if width is None or height is None: + # Use dimensions of the first frame if not provided + height, width = frames[0].shape[:2] + + # Define the codec and create VideoWriter object + fourcc = cv2.VideoWriter_fourcc(*codec) + video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) + + # Write frames to the video file + for frame in frames: + # Ensure each frame has the correct size + if frame.shape[1] != width or frame.shape[0] != height: + frame = cv2.resize(frame, (width, height)) + video_writer.write(frame) + + # Release the video writer + video_writer.release() + +def video_shredder(video_data, is_file_path=True) -> np.ndarray: + """ + Extract frames from a video file or in-memory video data and return them as a NumPy array. 
+
+    Args:
+        video_data (str or BytesIO): Path to the input video file or in-memory video data.
+        is_file_path (bool): Indicates if video_data is a file path (True) or in-memory data (False).
+
+    Returns:
+        np.ndarray: Array of frames with shape (num_frames, height, width, channels).
+    """
+    if is_file_path:
+        # Handle file-based video input.
+        video_capture = cv2.VideoCapture(video_data)
+    else:
+        # Handle in-memory video input.
+        # Create a temporary file to store in-memory video data.
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
+            temp_file.write(video_data.getvalue())
+            temp_file_path = temp_file.name
+
+        # Open the temporary video file.
+        video_capture = cv2.VideoCapture(temp_file_path)
+
+    if not video_capture.isOpened():
+        raise ValueError("Error opening video data")
+
+    frames = []
+    success, frame = video_capture.read()
+
+    while success:
+        frames.append(frame)
+        success, frame = video_capture.read()
+
+    video_capture.release()
+
+    # Delete the temporary file if it was created.
+    if not is_file_path:
+        os.remove(temp_file_path)
+
+    # Convert list of frames to a NumPy array.
+    frames_array = np.array(frames)
+    print(f"Extracted {frames_array.shape[0]} frames from video in shape of {frames_array.shape}")
+
+    return frames_array
 
 class SafetyChecker:
     """Checks images for unsafe or inappropriate content using a pretrained model.
@@ -175,3 +284,65 @@ def check_nsfw_images(
             clip_input=safety_checker_input.pixel_values.to(self._dtype),
         )
         return images, has_nsfw_concept
+
+
+def natural_sort_key(s):
+    """
+    Sort in a natural order, separating strings into a list of strings and integers.
+    This handles leading zeros and case insensitivity.
+    """
+    return [
+        int(text) if text.isdigit() else text.lower()
+        for text in re.split(r'([0-9]+)', os.path.basename(s))
+    ]
+
+class DirectoryReader:
+    def __init__(self, dir: str):
+        self.paths = sorted(
+            glob.glob(os.path.join(dir, "*")),
+            key=natural_sort_key
+        )
+        self.nb_frames = len(self.paths)
+        self.idx = 0
+
+        assert self.nb_frames > 0, "no frames found in directory"
+
+        first_img = Image.open(self.paths[0])
+        self.height = first_img.height
+        self.width = first_img.width
+
+    def get_resolution(self):
+        return self.height, self.width
+
+    def reset(self):
+        self.idx = 0  # Reset the index counter to 0.
+
+    def get_frame(self):
+        if self.idx >= self.nb_frames:
+            return None
+
+        path = self.paths[self.idx]
+        self.idx += 1
+
+        img = Image.open(path)
+        transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
+
+        return transforms(img)
+
+class DirectoryWriter:
+    def __init__(self, dir: str):
+        self.dir = dir
+        self.idx = 0
+
+    def open(self):
+        return
+
+    def close(self):
+        return
+
+    def write_frame(self, frame: torch.Tensor):
+        path = f"{self.dir}/{self.idx}.png"
+        self.idx += 1
+
+        transforms = v2.Compose([v2.ToPILImage()])
+        transforms(frame.squeeze(0)).save(path)
\ No newline at end of file
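Note: a short round-trip sketch for the two new helpers above; the file names are illustrative only, not part of this diff.

# Hypothetical round trip: shred a video into frames, then compact them back.
from app.pipelines.utils import frames_compactor, video_shredder

frames = video_shredder("input.mp4", is_file_path=True)  # (num_frames, H, W, C) uint8 array.
frames_compactor(list(frames), "output.avi", fps=24.0, codec="MJPEG")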
diff --git a/runner/app/routes/frame_interpolation.py b/runner/app/routes/frame_interpolation.py
new file mode 100644
index 00000000..6abf6a66
--- /dev/null
+++ b/runner/app/routes/frame_interpolation.py
@@ -0,0 +1,120 @@
+# app/routes/frame_interpolation.py
+
+import logging
+import os
+import torch
+import glob
+from typing import Annotated, Optional
+from fastapi import APIRouter, Depends, File, Form, UploadFile, status
+from fastapi.responses import JSONResponse
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from PIL import Image, ImageFile
+
+from app.dependencies import get_pipeline
+from app.pipelines.frame_interpolation import FILMPipeline
+from app.pipelines.utils.utils import DirectoryReader, DirectoryWriter, get_torch_device, get_model_dir
+from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+router = APIRouter()
+
+logger = logging.getLogger(__name__)
+
+RESPONSES = {
+    status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
+    status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
+    status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
+}
+
+@router.post("/frame_interpolation", response_model=ImageResponse, responses=RESPONSES)
+@router.post(
+    "/frame_interpolation/",
+    response_model=ImageResponse,
+    responses=RESPONSES,
+    include_in_schema=False,
+)
+async def frame_interpolation(
+    model_id: Annotated[str, Form()],
+    image1: Annotated[UploadFile, File()] = None,
+    image2: Annotated[UploadFile, File()] = None,
+    image_dir: Annotated[str, Form()] = "",
+    inter_frames: Annotated[int, Form()] = 2,
+    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
+):
+    auth_token = os.environ.get("AUTH_TOKEN")
+    if auth_token:
+        if not token or token.credentials != auth_token:
+            return JSONResponse(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                headers={"WWW-Authenticate": "Bearer"},
+                content=http_error("Invalid bearer token"),
+            )
+
+    # Initialize FILMPipeline.
+    film_pipeline = FILMPipeline(model_id)
+    film_pipeline.to(device=get_torch_device(), dtype=torch.float16)
+
+    # Prepare directories for input and output.
+    temp_input_dir = "temp_input"
+    temp_output_dir = "temp_output"
+    os.makedirs(temp_input_dir, exist_ok=True)
+    os.makedirs(temp_output_dir, exist_ok=True)
+
+    try:
+        if os.path.isdir(image_dir):
+            if image1 and image2:
+                logger.info("Both directory and individual images provided. Directory will be used, and images will be ignored.")
+            reader = DirectoryReader(image_dir)
+        else:
+            if not (image1 and image2):
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content=http_error("Either a directory or two images must be provided."),
+                )
+
+            image1_path = os.path.join(temp_input_dir, "0.png")
+            image2_path = os.path.join(temp_input_dir, "1.png")
+
+            with open(image1_path, "wb") as f:
+                f.write(await image1.read())
+            with open(image2_path, "wb") as f:
+                f.write(await image2.read())
+
+            reader = DirectoryReader(temp_input_dir)
+
+        writer = DirectoryWriter(temp_output_dir)
+        # Perform interpolation.
+        film_pipeline(reader, writer, inter_frames=inter_frames)
+
+        writer.close()
+        reader.reset()
+
+        # Collect output frames.
+        output_frames = []
+        for frame_path in sorted(glob.glob(os.path.join(temp_output_dir, "*.png"))):
+            frame = Image.open(frame_path)
+            output_frames.append(frame)
+
+        output_images = [{"url": image_to_data_url(frame), "seed": 0, "nsfw": False} for frame in output_frames]
+
+    except Exception as e:
+        logger.error(f"FILMPipeline error: {e}")
+        logger.exception(e)
+        return JSONResponse(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            content=http_error("FILMPipeline error"),
+        )
+
+    finally:
+        # Clean up temporary directories.
+        for file_path in glob.glob(os.path.join(temp_input_dir, "*")):
+            os.remove(file_path)
+        os.rmdir(temp_input_dir)
+
+        for file_path in glob.glob(os.path.join(temp_output_dir, "*")):
+            os.remove(file_path)
+        os.rmdir(temp_output_dir)
+
+    return {"images": output_images}
diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py
index f8b9c613..1dbf44bc 100644
--- a/runner/gen_openapi.py
+++ b/runner/gen_openapi.py
@@ -11,6 +11,7 @@
     image_to_video,
     segment_anything_2,
     text_to_image,
+    frame_interpolation,
     upscale,
 )
 from fastapi.openapi.utils import get_openapi
@@ -104,7 +105,7 @@ def translate_to_gateway(openapi: dict) -> dict:
     openapi["components"]["schemas"]["VideoResponse"]["title"] = "VideoResponse"
 
     return openapi
-    
+
 def write_openapi(fname: str, entrypoint: str = "runner", version: str = "0.0.0"):
     """Write OpenAPI schema to file.
@@ -120,8 +121,8 @@ def write_openapi(fname: str, entrypoint: str = "runner", version: str = "0.0.0"
     app.include_router(text_to_image.router)
     app.include_router(image_to_image.router)
     app.include_router(image_to_video.router)
-    app.include_router(upscale.router)
     app.include_router(audio_to_text.router)
+    app.include_router(upscale.router)
     app.include_router(segment_anything_2.router)
 
     logger.info(f"Generating OpenAPI schema for '{entrypoint}' entrypoint...")
diff --git a/runner/requirements.txt b/runner/requirements.txt
index 24f2442f..8d1fd44b 100644
--- a/runner/requirements.txt
+++ b/runner/requirements.txt
@@ -1,11 +1,15 @@
 diffusers==0.30.0
 accelerate==0.30.1
-transformers==4.41.1
+transformers==4.43.1
 fastapi==0.111.0
 pydantic==2.7.2
 Pillow==10.3.0
 python-multipart==0.0.9
 uvicorn==0.30.0
+setuptools==71.1.0
+torch --index-url https://download.pytorch.org/whl/cu121
+torchvision --index-url https://download.pytorch.org/whl/cu121
+torchaudio --index-url https://download.pytorch.org/whl/cu121
 huggingface_hub==0.23.2
 xformers==0.0.23
 triton>=2.1.0
@@ -17,3 +21,4 @@
 numpy==1.26.4
 av==12.1.0
 sentencepiece== 0.2.0
 protobuf==5.27.2
+opencv-python==4.10.0.84
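Note: a hedged client sketch for the new /frame_interpolation route above; the host, port, bearer token, checkpoint name, and input file names are assumptions, not part of this diff.

# Hypothetical request mirroring the route's Form/File parameters.
import requests

resp = requests.post(
    "http://localhost:8000/frame_interpolation",
    headers={"Authorization": "Bearer my-token"},  # Only needed when AUTH_TOKEN is set.
    data={"model_id": "film_net_fp16.pt", "inter_frames": 2},
    files={
        "image1": open("frame_0.png", "rb"),
        "image2": open("frame_1.png", "rb"),
    },
)
resp.raise_for_status()
print(len(resp.json()["images"]), "frames returned as data URLs")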