AttributeError: module 'jax.random' has no attribute 'KeyArray' #3

VVh5912 opened this issue May 3, 2024 · 0 comments


VVh5912 commented May 3, 2024

Hi, when I run the inference scripts (both text-to-image and style transfer), I encounter this error:

```
2024-05-03 06:36:47.216259: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-03 06:36:47.216309: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-03 06:36:47.217764: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-03 06:36:48.362862: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /content/gdrive/MyDrive/dreamstyler/dreamstyler/inference_style_transfer.py:12 in <module> │
│ │
│ 9 import imageio │
│ 10 import numpy as np │
│ 11 from PIL import Image │
│ ❱ 12 from diffusers import ControlNetModel, UniPCMultistepScheduler │
│ 13 from transformers import CLIPTextModel, CLIPTokenizer │
│ 14 from controlnet_aux.processor import Processor │
│ 15 import custom_pipelines │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/__init__.py:38 in <module> │
│ │
│ 35 except OptionalDependencyNotAvailable: │
│ 36 │ from .utils.dummy_pt_objects import * # noqa F403 │
│ 37 else: │
│ ❱ 38 │ from .models import ( │
│ 39 │ │ AsymmetricAutoencoderKL, │
│ 40 │ │ AutoencoderKL, │
│ 41 │ │ AutoencoderTiny, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py:36 in <module> │
│ │
│ 33 │ from .vq_model import VQModel │
│ 34 │
│ 35 if is_flax_available(): │
│ ❱ 36 │ from .controlnet_flax import FlaxControlNetModel │
│ 37 │ from .unet_2d_condition_flax import FlaxUNet2DConditionModel │
│ 38 │ from .vae_flax import FlaxAutoencoderKL │
│ 39 │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py:25 in <module> │
│ │
│ 22 from ..configuration_utils import ConfigMixin, flax_register_to_config │
│ 23 from ..utils import BaseOutput │
│ 24 from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps │
│ ❱ 25 from .modeling_flax_utils import FlaxModelMixin │
│ 26 from .unet_2d_blocks_flax import ( │
│ 27 │ FlaxCrossAttnDownBlock2D, │
│ 28 │ FlaxDownBlock2D, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py:46 in <module> │
│ │
│ 43 logger = logging.get_logger(__name__) │
│ 44 │
│ 45 │
│ ❱ 46 class FlaxModelMixin(PushToHubMixin): │
│ 47 │ r""" │
│ 48 │ Base class for all Flax models. │
│ 49 │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py:195 in │
│ FlaxModelMixin │
│ │
│ 192 │ │ ```""" │
│ 193 │ │ return self._cast_floating_to(params, jnp.float16, mask) │
│ 194 │ │
│ ❱ 195 │ def init_weights(self, rng: jax.random.KeyArray) -> Dict: │
│ 196 │ │ raise NotImplementedError(f"init_weights method has to be implemented for {self} │
│ 197 │ │
│ 198 │ @classmethod
│ │
│ /usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py:54 in getattr │
│ │
│ 51 │ │ raise AttributeError(message) │
│ 52 │ warnings.warn(message, DeprecationWarning, stacklevel=2) │
│ 53 │ return fn │
│ ❱ 54 │ raise AttributeError(f"module {module!r} has no attribute {name!r}") │
│ 55 │
│ 56 return getattr │
│ 57 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: module 'jax.random' has no attribute 'KeyArray'
```
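The last frames of the traceback suggest why merely importing `diffusers` fails. On the Python 3.10 interpreter shown, parameter annotations are evaluated when a `def` statement runs, so `rng: jax.random.KeyArray` in `FlaxModelMixin.init_weights` is looked up while the class body executes during import. The following is a minimal sketch of that mechanism without JAX; `fake_random`, `FlaxLikeMixin`, `define_model_class` and `failure_message` are hypothetical stand-ins, not part of diffusers or JAX.

```python
import types

# Hypothetical stub standing in for a `jax.random` that no longer exports
# `KeyArray`. This is NOT the real jax.random module.
fake_random = types.SimpleNamespace(PRNGKey=lambda seed: seed)


def define_model_class():
    """Build a class whose method annotation references the missing name."""

    class FlaxLikeMixin:
        # On Python <= 3.13 without `from __future__ import annotations`,
        # this annotation is evaluated as soon as the `def` runs, i.e. while
        # the class body executes (import time in the real library).
        def init_weights(self, rng: fake_random.KeyArray) -> dict:
            raise NotImplementedError

    return FlaxLikeMixin


def failure_message():
    """Return the AttributeError text, or None if no error was raised."""
    try:
        cls = define_model_class()
        # Interpreters that defer annotation evaluation (PEP 649, Python 3.14+)
        # only evaluate it here, when __annotations__ is first accessed.
        _ = cls.init_weights.__annotations__
    except AttributeError as exc:
        return str(exc)
    return None


print(failure_message())
```

Because the lookup happens at class-definition time, the error appears to surface on `from diffusers import ...` even though the style-transfer script only asks for the PyTorch `ControlNetModel`.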
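The traceback shows that diffusers imported its Flax models because `is_flax_available()` was true in this environment, and that the `AttributeError` is raised from `jax/_src/deprecations.py`, so the installed JAX and Flax versions seem to matter even for the PyTorch pipelines. A minimal diagnostic sketch, assuming Python 3.8+ (for `importlib.metadata`), that collects the relevant package versions and checks whether the installed `jax.random` still exposes `KeyArray`; the helper names are illustrative and not from this repository.

```python
import importlib.metadata
import importlib.util


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def jax_has_keyarray():
    """True/False if jax imports cleanly, None if jax is absent or broken."""
    if importlib.util.find_spec("jax") is None:
        return None
    try:
        import jax.random
    except ImportError:
        return None
    # hasattr() goes through jax's module-level __getattr__, so a removed
    # deprecated alias yields False instead of raising.
    return hasattr(jax.random, "KeyArray")


def report():
    """Gather the versions that decide whether diffusers' Flax import works."""
    return {
        "diffusers": installed_version("diffusers"),
        "jax": installed_version("jax"),
        "flax": installed_version("flax"),
        "jax.random.KeyArray present": jax_has_keyarray(),
    }


for name, value in report().items():
    print(f"{name}: {value}")
```

Pasting this output into the issue would let maintainers see which diffusers/JAX combination triggers the error without guessing the Colab environment.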
