From 96d4854455779576623d90c661f8c16b0247546e Mon Sep 17 00:00:00 2001
From: Omar Sanseviero
Date: Thu, 12 Oct 2023 15:06:09 +0200
Subject: [PATCH] Support both old and new diffusers import path (#843)

* Update modeling_sd_base.py

* Update trl/models/modeling_sd_base.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* make precommit

* cleaner approach

* oops

* better alternative

* rm uneeded file

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada
---
 trl/core.py                    | 51 +++++++++++++++++++++++++++++++++-
 trl/models/modeling_sd_base.py |  3 +-
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/trl/core.py b/trl/core.py
index cd01a83546..c86bf3718f 100644
--- a/trl/core.py
+++ b/trl/core.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import gc
 import random
+import warnings
 from contextlib import contextmanager
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -266,3 +267,51 @@ def empty_cuda_cache(cls):
             gc.collect()
             torch.cuda.empty_cache()
             gc.collect()
+
+
+def randn_tensor(
+    shape: Union[Tuple, List],
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional["torch.device"] = None,
+    dtype: Optional["torch.dtype"] = None,
+    layout: Optional["torch.layout"] = None,
+):
+    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
+    passing a list of generators, you can seed each batch item individually. If CPU generators are passed, the tensor
+    is always created on the CPU.
+    """
+    # device on which tensor is created defaults to device
+    rand_device = device
+    batch_size = shape[0]
+
+    layout = layout or torch.strided
+    device = device or torch.device("cpu")
+
+    if generator is not None:
+        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
+        if gen_device_type != device.type and gen_device_type == "cpu":
+            rand_device = "cpu"
+            if device != "mps":
+                warnings.warn(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slightly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif gen_device_type != device.type and gen_device_type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+    # make sure generator list of length 1 is treated like a non-list
+    if isinstance(generator, list) and len(generator) == 1:
+        generator = generator[0]
+
+    if isinstance(generator, list):
+        shape = (1,) + shape[1:]
+        latents = [
+            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
+            for i in range(batch_size)
+        ]
+        latents = torch.cat(latents, dim=0).to(device)
+    else:
+        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
+
+    return latents
diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py
index 4e0600df00..0d68380401 100644
--- a/trl/models/modeling_sd_base.py
+++ b/trl/models/modeling_sd_base.py
@@ -24,7 +24,8 @@
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
-from diffusers.utils.torch_utils import randn_tensor
+
+from ..core import randn_tensor
 
 
 @dataclass
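
Background on the subject line: diffusers moved `randn_tensor` from `diffusers.utils` to `diffusers.utils.torch_utils`, so importing from either fixed path breaks against one range of diffusers versions. A version-agnostic import along the lines below is one conventional workaround; the merged patch instead vendors the helper into `trl/core.py` so that neither path is needed (a sketch of the alternative, not the merged code):

    try:
        # newer diffusers releases expose the helper under torch_utils
        from diffusers.utils.torch_utils import randn_tensor
    except ImportError:
        # older diffusers releases kept it directly under diffusers.utils
        from diffusers.utils import randn_tensor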
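
For reference, a minimal usage sketch of the vendored helper with a per-sample generator list; the shape, seeds, and device below are illustrative assumptions, not part of the patch:

    import torch

    from trl.core import randn_tensor

    # one generator per batch element, so each sample is seeded independently
    generators = [torch.Generator(device="cpu").manual_seed(seed) for seed in (0, 1, 2, 3)]

    # draws each (1, 4, 64, 64) sample with its own generator, then concatenates
    # along the batch dimension into a (4, 4, 64, 64) tensor on the target device
    latents = randn_tensor((4, 4, 64, 64), generator=generators, device=torch.device("cpu"))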