Trainer's .init_module() context does not initialize model on target device #20307

@jin-zhe

Bug description

The documentation at https://lightning.ai/docs/pytorch/stable/advanced/model_init.html states that "you can force PyTorch to create the model directly on the target device" by using the .init_module() context. However, I have verified on several different GPU machines that this is not the case. The minimal script below prints the model's device right after it is initialized inside the context. It always prints 'cpu'.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

from torch import nn, optim
from pytorch_lightning import Trainer, LightningModule


class LitAutoEncoder(LightningModule):
  '''
  Model taken from https://lightning.ai/docs/pytorch/stable/starter/introduction.html
  Details unimportant
  '''
  def __init__(self):
    super().__init__()
    self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

  def training_step(self, batch, batch_idx):
    x, _ = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = nn.functional.mse_loss(x_hat, x)
    return loss

  def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=1e-3)
    return optimizer

trainer = Trainer(accelerator='gpu', devices=[0])
with trainer.init_module():
  model = LitAutoEncoder()
  print(model.device) # => cpu
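
A possible follow-up check, offered as a sketch rather than a confirmed diagnosis: `model.device` is a property of `LightningModule`, and it is only an assumption here (not verified in this report) that it could differ from where the weights were actually allocated. Reading the device directly from the parameters and buffers would remove that ambiguity. The helper name `tensor_devices` is hypothetical; it relies only on the standard `parameters()` and `buffers()` methods of `torch.nn.Module`.

```python
def tensor_devices(module):
    """Return the sorted device strings of all parameters and buffers.

    Works on any object exposing ``parameters()`` and ``buffers()``,
    such as a ``torch.nn.Module``. ``tensor_devices`` is a hypothetical
    helper name, not part of PyTorch or Lightning.
    """
    devices = set()
    for tensor in list(module.parameters()) + list(module.buffers()):
        # Every tensor carries its own ``device`` attribute.
        devices.add(str(tensor.device))
    return sorted(devices)


# Usage inside the reproduction script above (requires torch and a GPU):
#
#   with trainer.init_module():
#       model = LitAutoEncoder()
#   print(model.device)           # device reported by the LightningModule property
#   print(tensor_devices(model))  # devices of the tensors themselves
```

If the two printed values disagree, the discrepancy would lie in how the device is reported rather than in where the model is created.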

Environment

Current environment

* CUDA:
    - GPU:
        - Tesla V100-SXM2-32GB-LS
        - Tesla V100-SXM2-32GB
        - Tesla V100-SXM2-32GB-LS
        - Tesla V100-SXM2-32GB-LS
        - Tesla V100-SXM2-32GB-LS
        - Tesla V100-SXM2-32GB-LS
        - Tesla V100-SXM2-32GB-LS
        - Tesla V100-SXM2-32GB-LS
    - available: True
    - version: 11.8
* Lightning:
    - lightning: 2.4.0
    - lightning-utilities: 0.11.6
    - open-clip-torch: 2.26.1
    - pytorch-lightning: 2.4.0
    - torch: 2.1.0
    - torchaudio: 2.1.0
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.16.0
* Packages:
    - aiohappyeyeballs: 2.3.5
    - aiohttp: 3.10.3
    - aiosignal: 1.3.1
    - altair: 5.4.1
    - antlr4-python3-runtime: 4.9.3
    - appdirs: 1.4.4
    - asttokens: 2.4.1
    - async-timeout: 4.0.3
    - attrs: 24.2.0
    - autocommand: 2.2.2
    - backports.tarfile: 1.2.0
    - blinker: 1.8.2
    - brotli: 1.1.0
    - cachetools: 5.5.0
    - certifi: 2024.8.30
    - cffi: 1.17.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - colorama: 0.4.6
    - comm: 0.2.2
    - datasets: 2.20.0
    - debugpy: 1.8.5
    - decorator: 5.1.1
    - dill: 0.3.8
    - docker-pycreds: 0.4.0
    - einops: 0.8.0
    - exceptiongroup: 1.2.2
    - executing: 2.1.0
    - filelock: 3.15.4
    - frozenlist: 1.4.1
    - fsspec: 2024.5.0
    - ftfy: 6.2.3
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - gmpy2: 2.1.5
    - h2: 4.1.0
    - hpack: 4.0.0
    - huggingface-hub: 0.24.5
    - hyperframe: 6.0.1
    - idna: 3.7
    - importlib-metadata: 7.2.1
    - importlib-resources: 6.4.5
    - inflect: 7.3.1
    - ipykernel: 6.29.5
    - ipython: 8.27.0
    - ipywidgets: 8.1.5
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - jsonlines: 4.0.0
    - jsonschema: 4.23.0
    - jsonschema-specifications: 2023.12.1
    - jupyter-client: 8.6.3
    - jupyter-core: 5.7.2
    - jupyterlab-widgets: 3.0.13
    - lightning: 2.4.0
    - lightning-utilities: 0.11.6
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib-inline: 0.1.7
    - mdurl: 0.1.2
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - multiprocess: 0.70.16
    - narwhals: 1.8.2
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - numpy: 1.26.4
    - omegaconf: 2.3.0
    - open-clip-torch: 2.26.1
    - opencv-python: 4.10.0
    - opencv-python-headless: 4.10.0
    - ordered-set: 4.1.0
    - packaging: 24.1
    - pandas: 2.2.2
    - parso: 0.8.4
    - pathtools: 0.1.2
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 10.4.0
    - pip: 24.2
    - pkgutil-resolve-name: 1.3.10
    - platformdirs: 4.3.6
    - prompt-toolkit: 3.0.47
    - protobuf: 4.25.3
    - psutil: 6.0.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.3
    - pyarrow: 17.0.0
    - pyarrow-hotfix: 0.6
    - pycparser: 2.22
    - pydeck: 0.8.0b4
    - pygments: 2.18.0
    - pysocks: 1.7.1
    - python-dateutil: 2.9.0
    - pytorch-lightning: 2.4.0
    - pytz: 2024.1
    - pyyaml: 6.0.2
    - pyzmq: 26.2.0
    - referencing: 0.35.1
    - regex: 2024.7.24
    - requests: 2.32.3
    - rich: 13.8.1
    - rpds-py: 0.20.0
    - safetensors: 0.4.4
    - sentry-sdk: 2.12.0
    - setproctitle: 1.3.3
    - setuptools: 72.1.0
    - six: 1.16.0
    - smmap: 5.0.0
    - stack-data: 0.6.2
    - streamlit: 1.38.0
    - sympy: 1.13.2
    - tenacity: 8.5.0
    - timm: 1.0.8
    - tokenizers: 0.19.1
    - toml: 0.10.2
    - tomli: 2.0.1
    - torch: 2.1.0
    - torchaudio: 2.1.0
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.16.0
    - tornado: 6.4.1
    - tqdm: 4.66.5
    - traitlets: 5.14.3
    - transformers: 4.44.2
    - triton: 2.1.0
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - tzdata: 2024.1
    - tzlocal: 5.2
    - urllib3: 2.2.2
    - validators: 0.34.0
    - wandb: 0.16.6
    - watchdog: 4.0.1
    - wcwidth: 0.2.13
    - wheel: 0.44.0
    - widgetsnbextension: 4.0.13
    - xformers: 0.0.22.post7
    - xxhash: 3.4.1
    - yarl: 1.9.4
    - zipp: 3.20.2
    - zstandard: 0.23.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.14
    - release: 4.15.0-55-generic
    - version: #60-Ubuntu SMP Tue Jul 2 18:22:20 UTC 2019

How you installed Lightning (conda, pip, source): conda

More info

No response
