Lightning 2.x support #149

Open
maxwelltsai opened this issue Sep 18, 2024 · 0 comments
Labels
enhancement New feature or request

Comments

@maxwelltsai
Collaborator

Hi @cweniger et al.,

I notice that swyft is currently based on pytorch-lightning==1.9.5. Do you have any plans to upgrade this legacy dependency to a newer version, e.g., lightning==2.4.x?

A bit of background: in a collaboration with DAMTP (Cambridge), I am currently porting the swyft library to Intel GPUs, because their supercomputer "Dawn" is powered by Intel GPUs. My colleagues and I have made a version of Lightning that supports Intel GPUs, but it is based on Lightning 2.x. I therefore changed the swyft code to bump Lightning to 2.4, but swyft seems to rely on an API that is no longer supported in 2.x. The following error occurs when I call trainer.infer(network, obs, prior_samples):

Traceback (most recent call last):
  File "/nfs/site/home/xucai/Works/swyft/tests/truncation.py", line 78, in <module>
    predictions, bounds, samples = round(obs, bounds = bounds)
  File "/nfs/site/home/xucai/Works/swyft/tests/truncation.py", line 68, in round
    predictions = trainer.infer(network, obs, prior_samples)
  File "/nfs/site/home/xucai/Works/swyft/swyft/lightning/core.py", line 318, in infer
    ratio_batches = self.predict(model, dl)
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 858, in predict
    return call._call_and_handle_interrupt(
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 897, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1020, in _run_stage
    return self.predict_loop.run()
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/loops/prediction_loop.py", line 107, in run
    self.reset()
  File "/nfs/site/home/xucai/.conda/envs/swyft/lib/python3.9/site-packages/lightning/pytorch/loops/prediction_loop.py", line 176, in reset
    raise ValueError('`trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.')
ValueError: `trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.

How can we resolve this? Do we need to change how CombinedLoader is used here, given that it is currently in a non-sequential mode?

If it helps, I can submit a PR with the changes that bump the lightning version to 2.4.

Thanks,
Maxwell
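A note on the error above: in Lightning 2.x, `trainer.predict()` accepts several dataloaders only in `CombinedLoader(mode="sequential")` mode, where the loaders are iterated one after another and every batch reaches `predict_step` on its own. If, as the traceback suggests but the issue does not show, swyft's `infer` zips a one-item observation loader with a loader of prior-sample batches in a cycling mode such as `"max_size_cycle"`, then simply switching to `"sequential"` would not be enough, because the observation and the samples would never arrive in the same batch. The plain-Python sketch below contrasts the two pairing behaviours and shows one possible workaround: attach the observation to every batch up front, so that `predict()` receives a single loader. `max_size_cycle`, `sequential` and `pair_obs_with_batches` are hypothetical stand-ins written with `itertools`; they mimic the ordering only and are not Lightning's or swyft's implementation.

```python
from itertools import cycle, islice


def max_size_cycle(*loaders):
    """Hypothetical emulation of CombinedLoader(mode="max_size_cycle"):
    zip the loaders, cycling the shorter ones until the longest is exhausted."""
    items = [list(loader) for loader in loaders]
    longest = max(len(batches) for batches in items)
    return list(zip(*(islice(cycle(batches), longest) for batches in items)))


def sequential(*loaders):
    """Hypothetical emulation of CombinedLoader(mode="sequential"):
    visit each loader in turn, tagging every batch with its loader index."""
    return [(idx, batch) for idx, loader in enumerate(loaders) for batch in loader]


def pair_obs_with_batches(obs, sample_batches):
    """Hypothetical workaround: attach the single observation to every batch
    of prior samples, so predict() only ever sees ONE loader."""
    return [(obs, batch) for batch in sample_batches]


obs_loader = ["obs"]                # one observation
sample_loader = ["s0", "s1", "s2"]  # three batches of prior samples

print(max_size_cycle(obs_loader, sample_loader))
# [('obs', 's0'), ('obs', 's1'), ('obs', 's2')]
print(sequential(obs_loader, sample_loader))
# [(0, 'obs'), (1, 's0'), (1, 's1'), (1, 's2')]
print(pair_obs_with_batches("obs", sample_loader))
# [('obs', 's0'), ('obs', 's1'), ('obs', 's2')]
```

In real code the pairing would more likely live in a single `torch.utils.data.Dataset` or `DataLoader`, so that batching and device transfer still go through Lightning; the sketch only illustrates which items end up together in each batch.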

maxwelltsai added the enhancement (New feature or request) label on Sep 18, 2024