Skip to content

Commit

Permalink
Support both old and new diffusers import path (#843)
Browse files Browse the repository at this point in the history
* Update modeling_sd_base.py

* Update trl/models/modeling_sd_base.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* make precommit

* cleaner approach

* oops

* better alternative

* rm uneeded file

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
  • Loading branch information
3 people authored Oct 12, 2023
1 parent 3ef21a2 commit 96d4854
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
51 changes: 50 additions & 1 deletion trl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import warnings
from contextlib import contextmanager
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -266,3 +267,51 @@ def empty_cuda_cache(cls):
gc.collect()
torch.cuda.empty_cache()
gc.collect()


def randn_tensor(
shape: Union[Tuple, List],
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
device: Optional["torch.device"] = None,
dtype: Optional["torch.dtype"] = None,
layout: Optional["torch.layout"] = None,
):
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When
passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
is always created on the CPU.
"""
# device on which tensor is created defaults to device
rand_device = device
batch_size = shape[0]

layout = layout or torch.strided
device = device or torch.device("cpu")

if generator is not None:
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
if gen_device_type != device.type and gen_device_type == "cpu":
rand_device = "cpu"
if device != "mps":
warnings.warn(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
f" slighly speed up this function by passing a generator that was created on the {device} device."
)
elif gen_device_type != device.type and gen_device_type == "cuda":
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

# make sure generator list of length 1 is treated like a non-list
if isinstance(generator, list) and len(generator) == 1:
generator = generator[0]

if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

return latents
3 changes: 2 additions & 1 deletion trl/models/modeling_sd_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.utils.torch_utils import randn_tensor

from ..core import randn_tensor


@dataclass
Expand Down

0 comments on commit 96d4854

Please sign in to comment.