Pipeline to device (open-mmlab#210)
* Implement `pipeline.to(device)`

* `DiffusionPipeline.to()` decides the best device when passed `None`.

* Breaking change: `torch_device` removed from `__call__`

`pipeline.to()` now has PyTorch semantics.

* Use kwargs and a deprecation notice

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply torch_device compatibility to all pipelines.

* style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: anton-l <anton@huggingface.co>
3 people committed Aug 19, 2022
1 parent 89e9521 commit 71ba8ae
Showing 9 changed files with 147 additions and 60 deletions.
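In short: device placement moves from a per-call `torch_device` argument to PyTorch-style `pipeline.to(device)`. A minimal usage sketch (the checkpoint ID is illustrative, and the `["sample"]` indexing assumes the dict-style pipeline outputs of this era):

    import torch
    from diffusers import DDPMPipeline

    pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")  # illustrative checkpoint
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")        # moves every registered nn.Module
    image = pipe(batch_size=1)["sample"]                           # no torch_device argument anymore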
22 changes: 22 additions & 0 deletions src/diffusers/pipeline_utils.py
@@ -19,6 +19,8 @@
 import os
 from typing import Optional, Union
 
+import torch
+
 from huggingface_hub import snapshot_download
 from PIL import Image
 
@@ -113,6 +115,26 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
             save_method = getattr(sub_model, save_method_name)
             save_method(os.path.join(save_directory, pipeline_component_name))
 
+    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+        if torch_device is None:
+            return self
+
+        module_names, _ = self.extract_init_dict(dict(self.config))
+        for name in module_names.keys():
+            module = getattr(self, name)
+            if isinstance(module, torch.nn.Module):
+                module.to(torch_device)
+        return self
+
+    @property
+    def device(self) -> torch.device:
+        module_names, _ = self.extract_init_dict(dict(self.config))
+        for name in module_names.keys():
+            module = getattr(self, name)
+            if isinstance(module, torch.nn.Module):
+                return module.device
+        return torch.device("cpu")
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
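The new `to()` recovers the registered component names from the pipeline config via `extract_init_dict` and moves every `torch.nn.Module` among them; `device` reports the device of the first such module, falling back to CPU. A hedged sketch of the resulting behavior (component names vary by pipeline):

    pipe = pipe.to("cuda:0")                     # all nn.Module components move in one call
    assert pipe.device == torch.device("cuda", 0)
    noise = torch.randn(1, 3, 64, 64, device=pipe.device)  # co-locate tensors with the pipeline
    assert pipe.to(None) is pipe                 # None is a no-op, per the PyTorch-style semantics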
25 changes: 17 additions & 8 deletions src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
@@ -28,21 +30,28 @@ def __init__(self, unet, scheduler):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(
-        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
-    ):
-        # eta corresponds to η in paper and should be between [0, 1]
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
+
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
 
-        self.unet.to(torch_device)
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
+
+        # eta corresponds to η in paper and should be between [0, 1]
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
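Every pipeline below receives the same shim, so existing call sites keep working through the deprecation window. Both styles side by side (assuming a loaded DDIM pipeline named `pipe`; dict-style output as at this commit):

    # Deprecated: warns, then calls self.to(device) internally (removal planned for v0.3.0)
    image = pipe(batch_size=1, torch_device="cuda")["sample"]

    # New: place the pipeline once, then call without a device argument
    pipe.to("cuda")
    image = pipe(batch_size=1)["sample"]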
19 changes: 14 additions & 5 deletions src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
@@ -28,18 +30,25 @@ def __init__(self, unet, scheduler):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
 
-        self.unet.to(torch_device)
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(self.device)
 
         # set step values
         self.scheduler.set_timesteps(1000)
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -1,4 +1,5 @@
 import inspect
+import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -31,13 +32,22 @@ def __call__(
         guidance_scale: Optional[float] = 1.0,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        torch_device: Optional[Union[str, torch.device]] = None,
         output_type: Optional[str] = "pil",
+        **kwargs,
     ):
         # eta corresponds to η in paper and should be between [0, 1]
 
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -49,24 +59,20 @@ def __call__(
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
-        self.bert.to(torch_device)
-
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
             uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
-            uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]
+            uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
 
         # get prompt text embeddings
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
-        text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
+        text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
 
         latents = torch.randn(
             (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
         )
-        latents = latents.to(torch_device)
+        latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -1,4 +1,5 @@
 import inspect
+import warnings
 
 import torch
 
@@ -14,22 +15,26 @@ def __init__(self, vqvae, unet, scheduler):
         self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(
-        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
-    ):
+    def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
         # eta corresponds to η in paper and should be between [0, 1]
 
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
 
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        latents = latents.to(torch_device)
+        latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
20 changes: 15 additions & 5 deletions src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
@@ -28,20 +30,28 @@ def __init__(self, unet, scheduler):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
+    def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_type="pil", **kwargs):
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        self.unet.to(torch_device)
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
         for t in tqdm(self.scheduler.timesteps):
24 changes: 17 additions & 7 deletions src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+import warnings
+
 import torch
 
 from diffusers import DiffusionPipeline
@@ -11,24 +13,32 @@ def __init__(self, unet, scheduler):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
-
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         img_size = self.unet.config.sample_size
         shape = (batch_size, 3, img_size, img_size)
 
-        model = self.unet.to(torch_device)
+        model = self.unet
 
         sample = torch.randn(*shape) * self.scheduler.config.sigma_max
-        sample = sample.to(torch_device)
+        sample = sample.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
 
         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
-            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
+            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
 
             # correction step
             for _ in range(self.scheduler.correct_steps):
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -1,4 +1,5 @@
 import inspect
+import warnings
 from typing import List, Optional, Union
 
 import torch
@@ -45,11 +46,20 @@ def __call__(
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        torch_device: Optional[Union[str, torch.device]] = None,
         output_type: Optional[str] = "pil",
+        **kwargs,
     ):
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -61,11 +71,6 @@ def __call__(
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        self.unet.to(torch_device)
-        self.vae.to(torch_device)
-        self.text_encoder.to(torch_device)
-        self.safety_checker.to(torch_device)
-
         # get prompt text embeddings
         text_input = self.tokenizer(
             prompt,
@@ -74,7 +79,7 @@ def __call__(
             truncation=True,
             return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -86,7 +91,7 @@ def __call__(
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -97,7 +102,7 @@ def __call__(
         latents = torch.randn(
             (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
-            device=torch_device,
+            device=self.device,
         )
 
         # set timesteps
@@ -150,7 +155,7 @@ def __call__(
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
         # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
+        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
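For Stable Diffusion, the four per-module `.to()` calls collapse into one pipeline-level move covering `unet`, `vae`, `text_encoder`, and `safety_checker`. A sketch (checkpoint ID and token handling are illustrative, not part of this commit):

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
    pipe.to("cuda")  # one call places all four components
    image = pipe("a photo of an astronaut riding a horse")["sample"][0]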