Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ImportError: cannot import name 'xla_call_p' from 'jax.interpreters.xla' #16214

Closed
hlcxq1223 opened this issue Jun 1, 2023 · 2 comments
Closed
Assignees
Labels
bug Something isn't working

Comments

@hlcxq1223
Copy link

hlcxq1223 commented Jun 1, 2023

Description

When I ran the code below, I got an import error:

import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import cell2location
from matplotlib import rcParams
rcParams['pdf.fonttype'] = 42 # enables correct plotting of text for PDFs
Global seed set to 0
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[2], line 6
      3 import matplotlib.pyplot as plt
      4 import matplotlib as mpl
----> 6 import cell2location
      8 from matplotlib import rcParams
      9 rcParams['pdf.fonttype'] = 42 # enables correct plotting of text for PDFs

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/cell2location/__init__.py:9
      6 from rich.logging import RichHandler
      7 from torch.distributions import biject_to, transform_to
----> 9 from . import models
     10 from .run_colocation import run_colocation
     12 # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
     13 # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/cell2location/models/__init__.py:1
----> 1 from ._cell2location_model import Cell2location
      2 from ._cell2location_module import (
      3     LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel,
      4 )
      5 from ._cell2location_WTA_model import Cell2location_WTA

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/cell2location/models/_cell2location_model.py:11
      9 from pyro.infer import Trace_ELBO, TraceEnum_ELBO
     10 from pyro.nn import PyroModule
---> 11 from scvi import REGISTRY_KEYS
     12 from scvi.data import AnnDataManager
     13 from scvi.data.fields import (
     14     CategoricalJointObsField,
     15     CategoricalObsField,
   (...)
     18     NumericalObsField,
     19 )

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/__init__.py:10
      7 from ._settings import settings
      9 # this import needs to come after prior imports to prevent circular import
---> 10 from . import autotune, data, model, external, utils
     12 from importlib.metadata import version
     14 package_name = "scvi-tools"

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/autotune/__init__.py:1
----> 1 from ._manager import TuneAnalysis, TunerManager
      2 from ._tuner import ModelTuner
      3 from ._types import Tunable, TunableMixin

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/autotune/_manager.py:22
     20 from scvi._types import AnnOrMuData
     21 from scvi.data._constants import _SETUP_ARGS_KEY, _SETUP_METHOD_NAME
---> 22 from scvi.model.base import BaseModelClass
     23 from scvi.utils import InvalidParameterError
     25 from ._defaults import COLORS, COLUMN_KWARGS, DEFAULTS, TUNABLE_TYPES

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/model/__init__.py:2
      1 from . import utils
----> 2 from ._amortizedlda import AmortizedLDA
      3 from ._autozi import AUTOZI
      4 from ._condscvi import CondSCVI

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/model/_amortizedlda.py:14
     12 from scvi.data import AnnDataManager
     13 from scvi.data.fields import LayerField
---> 14 from scvi.module import AmortizedLDAPyroModule
     15 from scvi.utils import setup_anndata_dsp
     17 from .base import BaseModelClass, PyroSviTrainMixin

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/module/__init__.py:1
----> 1 from ._amortizedlda import AmortizedLDAPyroModule
      2 from ._autozivae import AutoZIVAE
      3 from ._classifier import Classifier

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/module/_amortizedlda.py:14
     12 from scvi._constants import REGISTRY_KEYS
     13 from scvi.autotune._types import Tunable
---> 14 from scvi.module.base import PyroBaseModuleClass, auto_move_data
     15 from scvi.nn import Encoder
     17 _AMORTIZED_LDA_PYRO_MODULE_NAME = "amortized_lda"

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/module/base/__init__.py:1
----> 1 from ._base_module import (
      2     BaseMinifiedModeModuleClass,
      3     BaseModuleClass,
      4     JaxBaseModuleClass,
      5     LossOutput,
      6     PyroBaseModuleClass,
      7     TrainStateWithState,
      8 )
      9 from ._decorators import auto_move_data, flax_configure
     11 __all__ = [
     12     "BaseModuleClass",
     13     "LossOutput",
   (...)
     19     "BaseMinifiedModeModuleClass",
     20 ]

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/scvi/module/base/_base_module.py:18
     16 from jax import random
     17 from jaxlib.xla_extension import Device
---> 18 from numpyro.distributions import Distribution
     19 from pyro.infer.predictive import Predictive
     20 from torch import nn

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/numpyro/__init__.py:6
      1 # Copyright Contributors to the Pyro project.
      2 # SPDX-License-Identifier: Apache-2.0
      4 import logging
----> 6 from numpyro import compat, diagnostics, distributions, handlers, infer, ops, optim
      7 from numpyro.distributions.distribution import enable_validation, validation_enabled
      8 from numpyro.infer.inspect import render_model

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/numpyro/infer/__init__.py:5
      1 # Copyright Contributors to the Pyro project.
      2 # SPDX-License-Identifier: Apache-2.0
      4 from numpyro.infer.barker import BarkerMH
----> 5 from numpyro.infer.elbo import (
      6     ELBO,
      7     RenyiELBO,
      8     Trace_ELBO,
      9     TraceEnum_ELBO,
     10     TraceGraph_ELBO,
     11     TraceMeanField_ELBO,
     12 )
     13 from numpyro.infer.hmc import HMC, NUTS
     14 from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/numpyro/infer/elbo.py:24
     17 from numpyro.handlers import replay, seed, substitute, trace
     18 from numpyro.infer.util import (
     19     _without_rsample_stop_gradient,
     20     get_importance_trace,
     21     is_identically_one,
     22     log_density,
     23 )
---> 24 from numpyro.ops.provenance import eval_provenance
     25 from numpyro.util import _validate_model, check_model_guide_match, find_stack_level
     28 class ELBO:

File ~/miniconda3/envs/cell2loc_env/lib/python3.9/site-packages/numpyro/ops/provenance.py:10
      8 from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
      9 from jax.interpreters.pxla import xla_pmap_p
---> 10 from jax.interpreters.xla import xla_call_p
     11 import jax.linear_util as lu
     12 import jax.numpy as jnp

ImportError: cannot import name 'xla_call_p' from 'jax.interpreters.xla'

Is this error caused by the latest version, v0.4.11? Thanks a lot!

What jax/jaxlib version are you using?

jax v0.4.11; jaxlib v0.4.11

Which accelerator(s) are you using?

CPU

Additional system info

Linux

NVIDIA GPU info

No response

@hlcxq1223 hlcxq1223 added the bug Something isn't working label Jun 1, 2023
@jakevdp
Copy link
Collaborator

jakevdp commented Jun 1, 2023

Hi - this is expected. xla_call_p has been deprecated since JAX v0.4.4, and was removed in JAX version 0.4.11. See the Change Log for more information.

If you have code that still uses xla_call_p, you can install JAX version 0.4.10 or older, but I'd suggest updating your code if possible.

@jakevdp jakevdp closed this as completed Jun 1, 2023
@jakevdp jakevdp self-assigned this Jun 1, 2023
@jakevdp
Copy link
Collaborator

jakevdp commented Jun 1, 2023

By the way, it looks like numpyro has already fixed the issue here: pyro-ppl/numpyro#1595

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

2 participants