Sfast optimization for T2I, I2I and Upscale models #134

Open · wants to merge 41 commits into base: main

Commits (41)
fd1ab58
update to the frame interpolation pipeline, there is some minor issue…
jjassonn Jul 25, 2024
36f4ea5
Merge branch 'frame-interpolation' of https://github.com/JJassonn69/a…
jjassonn Jul 25, 2024
65d2c53
minor changes to requirements
jjassonn Jul 26, 2024
796796d
update to requrements to fetch from --index-url
jjassonn Jul 26, 2024
b5eb66d
simple patch to solve the go api bindings issue
jjassonn Jul 26, 2024
522ca4f
checking if it works in my system
jjassonn Jul 26, 2024
c1b5ca1
Create docker-image.yml
JJassonn69 Jul 26, 2024
c9bf8d2
Delete .github/workflows/docker-image.yml
JJassonn69 Jul 26, 2024
d1b5d3c
Create validate-openapi-on-push.yaml
JJassonn69 Jul 26, 2024
8f82e52
Create docker-create-ai-runner.yaml
JJassonn69 Jul 26, 2024
6d1d32a
Update docker-create-ai-runner.yaml
JJassonn69 Jul 26, 2024
4812c2a
Frame interpolation (#5)
JJassonn69 Jul 26, 2024
9db6034
Update docker-create-ai-runner.yaml
JJassonn69 Jul 26, 2024
3ed7792
Delete .github/workflows/ai-runner-docker.yaml
JJassonn69 Jul 26, 2024
d032b0f
Delete .github/workflows/validate-openapi-on-pr.yaml
JJassonn69 Jul 26, 2024
9606ba6
Update docker-create-ai-runner.yaml
JJassonn69 Jul 26, 2024
32ebe0f
Update validate-openapi-on-push.yaml
JJassonn69 Jul 26, 2024
ce33b20
Update trigger-upstream-openapi-sync.yaml
JJassonn69 Jul 26, 2024
ca1d2f2
chore(deps): bump github.com/docker/docker (#6)
dependabot[bot] Jul 26, 2024
56d802d
Update docker-create-ai-runner.yaml
JJassonn69 Jul 26, 2024
01f229a
Update docker-create-ai-runner.yaml
JJassonn69 Jul 26, 2024
580224c
Update docker-create-ai-runner.yaml
JJassonn69 Jul 26, 2024
4f9f578
Update docker-create-ai-runner.yaml
JJassonn69 Jul 26, 2024
fa07ec6
Frame interpolation merging from branch (#7)
JJassonn69 Jul 26, 2024
672f5fd
Delete .github/workflows/trigger-upstream-openapi-sync.yaml
JJassonn69 Jul 26, 2024
ef155aa
test-examples for frame-interpolation
jjassonn Jul 26, 2024
950cdf9
update to sfast optimization to i2i and t2i and upscale pipelines
jjassonn Jul 28, 2024
b48692b
changes to extra files
jjassonn Jul 28, 2024
2d25c46
added git ignore to the files to remove unnecessary files
jjassonn Jul 28, 2024
eb5ae46
files not removed checking again
jjassonn Jul 28, 2024
e9b3965
still in test phase
jjassonn Jul 28, 2024
19261b6
test-test
jjassonn Jul 28, 2024
cb6498e
Update .gitignore
JJassonn69 Jul 28, 2024
b91da52
Delete runner/app/tests-examples directory
JJassonn69 Jul 28, 2024
12a925e
update to directory reader as it now reads almost any naming convention
jjassonn Jul 29, 2024
be10090
Merge branch 'sfast_optimization' of https://github.com/JJassonn69/ai…
jjassonn Jul 29, 2024
8b56e80
Merge branch 'main' into sfast_optimization
JJassonn69 Jul 29, 2024
cd45ee0
changing files similar to main to make easy to merge
jjassonn Jul 30, 2024
0885718
update to upscale warmup params to fix OOM
jjassonn Jul 30, 2024
3de64b3
naming wrong in the info section for error msg
jjassonn Jul 31, 2024
d58b4dc
Merge branch 'main' into pr/134
JJassonn69 Sep 20, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ai-runner-docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,4 @@ jobs:
file: "Dockerfile"
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=livepeerci/build:cache
cache-to: type=registry,ref=livepeerci/build:cache,mode=max
cache-to: type=registry,ref=livepeerci/build:cache,mode=max
2 changes: 1 addition & 1 deletion .github/workflows/validate-openapi-on-pr.yaml
Expand Up @@ -46,4 +46,4 @@ jobs:
if ! git diff --exit-code; then
echo "::error::Go bindings have changed. Please run 'make' at the root of the repository and commit the changes."
exit 1
fi
fi
2 changes: 1 addition & 1 deletion go.mod
Expand Up @@ -5,7 +5,7 @@ go 1.21.5
require (
github.com/deepmap/oapi-codegen/v2 v2.2.0
github.com/docker/cli v24.0.5+incompatible
github.com/docker/docker v24.0.7+incompatible
github.com/docker/docker v24.0.9+incompatible
github.com/docker/go-connections v0.4.0
github.com/getkin/kin-openapi v0.124.0
github.com/go-chi/chi/v5 v5.0.12
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Expand Up @@ -17,8 +17,8 @@ github.com/docker/cli v24.0.5+incompatible h1:WeBimjvS0eKdH4Ygx+ihVq1Q++xg36M/rM
github.com/docker/cli v24.0.5+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk=
github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
github.com/docker/docker v24.0.7+incompatible h1:Wo6l37AuwP3JaMnZa226lzVXGA3F9Ig1seQen0cKYlM=
github.com/docker/docker v24.0.7+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/docker v24.0.9+incompatible h1:HPGzNmwfLZWdxHqK9/II92pyi1EpYKsAqcl4G0Of9v0=
github.com/docker/docker v24.0.9+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ=
github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
Expand Down
18 changes: 12 additions & 6 deletions runner/app/main.py
@@ -1,5 +1,7 @@
import logging
import os
import sys
import cv2
from contextlib import asynccontextmanager

from app.routes import health
Expand All @@ -15,8 +17,8 @@ async def lifespan(app: FastAPI):

app.include_router(health.router)

pipeline = os.environ["PIPELINE"]
model_id = os.environ["MODEL_ID"]
pipeline = os.environ.get("PIPELINE", "")  # Default to empty string if unset
model_id = os.environ.get("MODEL_ID", "")  # Provide a default if necessary

app.pipeline = load_pipeline(pipeline, model_id)
app.include_router(load_route(pipeline))
Expand Down Expand Up @@ -44,8 +46,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.audio_to_text import AudioToTextPipeline

return AudioToTextPipeline(model_id)
case "frame-interpolation":
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "FILMPipeline":
from app.pipelines.frame_interpolation import FILMPipeline

return FILMPipeline(model_id)
case "upscale":
from app.pipelines.upscale import UpscalePipeline

Expand Down Expand Up @@ -78,8 +82,10 @@ def load_route(pipeline: str) -> any:
from app.routes import audio_to_text

return audio_to_text.router
case "frame-interpolation":
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "FILMPipeline":
from app.routes import frame_interpolation

return frame_interpolation.router
case "upscale":
from app.routes import upscale

Expand Down
114 changes: 111 additions & 3 deletions runner/app/pipelines/frame_interpolation.py
@@ -1,5 +1,113 @@
from app.pipelines.base import Pipeline
import torch
from torchvision.transforms import v2
from tqdm import tqdm
import bisect
import numpy as np
from app.pipelines.utils.utils import get_model_dir


class FrameInterpolationPipeline(Pipeline):
pass
class FILMPipeline:
model: torch.jit.ScriptModule

def __init__(self, model_id: str):
self.model_id = model_id
model_dir = get_model_dir() # Get the directory where models are stored
model_path = f"{model_dir}/{model_id}" # Construct the full path to the model file

self.model = torch.jit.load(model_path, map_location="cpu")
self.model.eval()

def to(self, *args, **kwargs):
self.model = self.model.to(*args, **kwargs)
return self

@property
def device(self) -> torch.device:
# Checking device for ScriptModule requires checking one of its parameters
params = self.model.parameters()
return next(params).device

@property
def dtype(self) -> torch.dtype:
# Checking dtype for ScriptModule requires checking one of its parameters
params = self.model.parameters()
return next(params).dtype

def __call__(
self,
reader,
writer,
inter_frames: int = 2,
):
transforms = v2.Compose(
[
v2.ToDtype(torch.uint8, scale=True),
]
)

writer.open()

while True:
frame_1 = reader.get_frame()
# If the first frame read is None then there are no more frames
if frame_1 is None:
break

frame_2 = reader.get_frame()
# If the second frame read is None then there is a final frame
if frame_2 is None:
writer.write_frame(transforms(frame_1))
break

# frame_1 and frame_2 must be tensors with n c h w format
frame_1 = frame_1.unsqueeze(0)
frame_2 = frame_2.unsqueeze(0)

frames = inference(
self.model, frame_1, frame_2, inter_frames, self.device, self.dtype
)

frames = [transforms(frame.detach().cpu()) for frame in frames]
for frame in frames:
writer.write_frame(frame)

writer.close()


def inference(
model, img_batch_1, img_batch_2, inter_frames, device, dtype
) -> torch.Tensor:
results = [img_batch_1, img_batch_2]

idxes = [0, inter_frames + 1]
remains = list(range(1, inter_frames + 1))

splits = torch.linspace(0, 1, inter_frames + 2)

for _ in tqdm(range(len(remains)), "Generating in-between frames"):
starts = splits[idxes[:-1]]
ends = splits[idxes[1:]]
distances = (
(splits[None, remains] - starts[:, None])
/ (ends[:, None] - starts[:, None])
- 0.5
).abs()
matrix = torch.argmin(distances).item()
start_i, step = np.unravel_index(matrix, distances.shape)
end_i = start_i + 1

x0 = results[start_i].to(device=device, dtype=dtype)
x1 = results[end_i].to(device=device, dtype=dtype)

dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (
splits[idxes[end_i]] - splits[idxes[start_i]]
)

with torch.no_grad():
prediction = model(x0, x1, dt)
insert_position = bisect.bisect_left(idxes, remains[step])
idxes.insert(insert_position, remains[step])
results.insert(insert_position, prediction.clamp(0, 1).float())
del remains[step]

return results
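The index bookkeeping in `inference` is dense: `splits` places the two input frames at t=0 and t=1 with `inter_frames` evenly spaced timestamps between them, and each generated frame is inserted at the remaining slot closest to the midpoint of its bracketing pair. A pure-Python sketch of the timestamp grid and the normalized `dt` (illustration only, no torch):

```python
# Equivalent of torch.linspace(0, 1, inter_frames + 2) for inter_frames = 2:
# the endpoints are the input frames, interior points are in-between frames.
inter_frames = 2
splits = [i / (inter_frames + 1) for i in range(inter_frames + 2)]

# Normalized time dt for the frame at slot `mid` between slots `lo` and `hi`,
# mirroring the dt computation passed to the FILM model in inference().
def normalized_dt(lo: int, mid: int, hi: int) -> float:
    return (splits[mid] - splits[lo]) / (splits[hi] - splits[lo])

print(splits)
print(normalized_dt(0, 1, 3))
```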
30 changes: 25 additions & 5 deletions runner/app/pipelines/image_to_image.py
@@ -1,6 +1,7 @@
import logging
import os
from enum import Enum
import time
from typing import List, Optional, Tuple

import PIL
Expand Down Expand Up @@ -29,6 +30,7 @@

logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2 # Model warm-up iterations when SFAST is enabled.

class ModelName(Enum):
"""Enumeration mapping model names to their corresponding IDs."""
Expand Down Expand Up @@ -146,11 +148,29 @@ def __init__(self, model_id: str):
# Warm-up the pipeline.
# TODO: Not yet supported for ImageToImagePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"ImageToImagePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)
warmup_kwargs = {
"prompt":"A warmed up pipeline is a happy pipeline a short poem by ricksta",
"image": PIL.Image.new("RGB", (576, 1024)),
"strength": 0.8,
"negative_prompt": "No blurry or weird artifacts",
"num_images_per_prompt":4,
}

logger.info("Warming up ImageToImagePipeline pipeline...")
total_time = 0
for ii in range(SFAST_WARMUP_ITERATIONS):
t = time.time()
try:
self.ldm(**warmup_kwargs).images
except Exception as e:
logger.error(f"ImageToImagePipeline warmup error: {e}")
raise e
iteration_time = time.time() - t
total_time += iteration_time
logger.info(
"Warmup iteration %s took %s seconds", ii + 1, iteration_time
)
logger.info("Total warmup time: %s seconds", total_time)

if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
Expand Down
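The warm-up added here follows the same pattern in all three pipelines (image-to-image, text-to-image, upscale): run the compiled model a fixed number of times with representative kwargs and log per-iteration wall time. A generic sketch of that loop, where the callable stands in for `self.ldm(**warmup_kwargs)`:

```python
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2  # mirrors the constant introduced in this PR

def warmup(run, iterations: int = SFAST_WARMUP_ITERATIONS) -> float:
    """Invoke `run` repeatedly, logging each iteration; return total seconds."""
    total = 0.0
    for ii in range(iterations):
        t = time.time()
        try:
            run()  # in the pipelines this is self.ldm(**warmup_kwargs).images
        except Exception as e:
            # An OOM here can leave the compiled pipeline corrupted, so fail fast.
            logger.error("Warmup error: %s", e)
            raise
        iteration_time = time.time() - t
        total += iteration_time
        logger.info("Warmup iteration %s took %s seconds", ii + 1, iteration_time)
    logger.info("Total warmup time: %s seconds", total)
    return total

total = warmup(lambda: sum(range(100_000)))
```

Warming up matters with sfast because the first calls trigger tracing/compilation; without it, the first user request absorbs that cost.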
6 changes: 5 additions & 1 deletion runner/app/pipelines/image_to_video.py
Expand Up @@ -6,10 +6,14 @@
import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import SafetyChecker, get_model_dir, get_torch_device
from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import file_download
from PIL import ImageFile
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device
)

ImageFile.LOAD_TRUNCATED_IMAGES = True

Expand Down
33 changes: 26 additions & 7 deletions runner/app/pipelines/text_to_image.py
@@ -1,6 +1,7 @@
import logging
import os
from enum import Enum
import time
from typing import List, Optional, Tuple

import PIL
Expand Down Expand Up @@ -28,6 +29,7 @@

logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2 # Model warm-up iterations when SFAST is enabled.

class ModelName(Enum):
"""Enumeration mapping model names to their corresponding IDs."""
Expand Down Expand Up @@ -173,14 +175,31 @@ def __init__(self, model_id: str):

self.ldm = compile_model(self.ldm)

# Warm-up the pipeline.
# TODO: Not yet supported for TextToImagePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"TextToImagePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)
# Retrieve default model params.
# TODO: Retrieve defaults from Pydantic class in route.
warmup_kwargs = {
"prompt": "A happy pipe in the line looking at the wall with words sfast",
"num_images_per_prompt": 4,
"negative_prompt": "No blurry or weird artifacts",
}

logger.info("Warming up TextToImagePipeline pipeline...")
total_time = 0
for ii in range(SFAST_WARMUP_ITERATIONS):
t = time.time()
try:
self.ldm(**warmup_kwargs).images
except Exception as e:
# FIXME: When out of memory, pipeline is corrupted.
logger.error(f"TextToImagePipeline warmup error: {e}")
raise e
iteration_time = time.time() - t
total_time += iteration_time
logger.info(
"Warmup iteration %s took %s seconds", ii + 1, iteration_time
)
logger.info("Total warmup time: %s seconds", total_time)

if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
Expand Down
32 changes: 26 additions & 6 deletions runner/app/pipelines/upscale.py
@@ -1,5 +1,6 @@
import logging
import os
import time
from typing import List, Optional, Tuple

import PIL
Expand All @@ -20,6 +21,7 @@

logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2 # Model warm-up iterations when SFAST is enabled.

class UpscalePipeline(Pipeline):
def __init__(self, model_id: str):
Expand Down Expand Up @@ -67,11 +69,29 @@ def __init__(self, model_id: str):
# Warm-up the pipeline.
# TODO: Not yet supported for UpscalePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"UpscalePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)
# Retrieve default model params.
# TODO: Retrieve defaults from Pydantic class in route.
warmup_kwargs = {
"prompt": "Upscaling the pipeline with sfast enabled",
"image": PIL.Image.new("RGB", (400, 400)), # anything higher than this size cause the model to OOM
}

logger.info("Warming up Upscale pipeline...")
total_time = 0
for ii in range(SFAST_WARMUP_ITERATIONS):
t = time.time()
try:
self.ldm(**warmup_kwargs).images
except Exception as e:
# FIXME: When out of memory, pipeline is corrupted.
logger.error(f"Upscale pipeline warmup error: {e}")
raise e
iteration_time = time.time() - t
total_time += iteration_time
logger.info(
"Warmup iteration %s took %s seconds", ii + 1, iteration_time
)
logger.info("Total warmup time: %s seconds", total_time)

if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
Expand All @@ -86,7 +106,7 @@ def __init__(self, model_id: str):
elif deepcache_enabled:
logger.warning(
"DeepCache is not supported for Lightning or Turbo models. "
"TextToImagePipeline will NOT be optimized with DeepCache for %s",
"UpscalePipeline will NOT be optimized with DeepCache for %s",
model_id,
)

Expand Down
4 changes: 4 additions & 0 deletions runner/app/pipelines/utils/__init__.py
Expand Up @@ -11,4 +11,8 @@
is_turbo_model,
split_prompt,
validate_torch_device,
frames_compactor,
video_shredder,
DirectoryReader,
DirectoryWriter
)