
use torch.inference_mode() in Trainer.predict #11018

Closed
davidegraff opened this issue Dec 9, 2021 · 3 comments · Fixed by #12715
Labels
feature Is an improvement or enhancement

Comments

@davidegraff

davidegraff commented Dec 9, 2021

🚀 Feature

PyTorch 1.9 introduced a new decorator/context manager specifically for model inference: torch.inference_mode(). Like torch.no_grad(), it disables gradient tracking for faster forward passes through a model and is, as the name implies, intended for model inference. Unlike torch.no_grad(), it also disables view tracking and version counter bumps, so tensors created under torch.inference_mode() cannot later be used in computations that require gradients. Skipping this extra bookkeeping further speeds up model forward passes.
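A minimal sketch of the behavioral difference (illustrative only; the variable names are made up for this example):

import torch

x = torch.ones(3, requires_grad=True)

# under no_grad, the result doesn't track gradients but is an ordinary
# tensor that can still participate in later autograd computations
with torch.no_grad():
    y = x * 2
print(y.requires_grad)  # False

# under inference_mode, view tracking and version counter bumps are also
# skipped; the result is an "inference tensor"
with torch.inference_mode():
    z = x * 2
print(z.is_inference())  # True

# using an inference tensor in a computation that records gradients fails
try:
    (z * x).sum().backward()
except RuntimeError as e:
    print(e)  # "Inference tensors cannot be saved for backward..."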

Motivation

The motivation here is to use torch.inference_mode() inside the Trainer.predict() logic, replacing any occurrences of torch.no_grad(). Because this code path is used only for inference, I don't believe the lack of view tracking is a significant drawback or limitation for downstream code. Conceptually, clients of Trainer.predict() should not rely on the returned tensors for downstream computations that require gradients. Clients are, however, likely relying on PyTorch Lightning to transparently scale and speed up their code, and using torch.inference_mode() furthers this goal.

Pitch

change

def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
    self.reset_predict_dataloader(self.lightning_module)
    # reset trainer on this loop and all child loops in case user connected a custom loop
    self.predict_loop.trainer = self
    with torch.no_grad():
        return self.predict_loop.run()

to

def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
    self.reset_predict_dataloader(self.lightning_module)
    # reset trainer on this loop and all child loops in case user connected a custom loop
    self.predict_loop.trainer = self
    # inference_mode additionally skips view tracking and version counter bumps
    with torch.inference_mode():
        return self.predict_loop.run()

I think this would be the only change needed.

cc @Borda

@davidegraff davidegraff added the feature Is an improvement or enhancement label Dec 9, 2021
@ananthsub
Contributor

ananthsub commented Dec 9, 2021

Here is a prior issue for this: #8499
We added initial support (#8813), but unfortunately we needed to revert the PR because certain collective communication libraries didn't support it: #9431

If gloo now works with inference mode, we could try adding this back.
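For context, a minimal repro sketch one could use to check whether gloo collectives accept inference-mode tensors (my own example, not from the linked PRs; it assumes a single-process gloo group, and the address/port are placeholders):

import torch
import torch.distributed as dist

# single-process gloo group, just to exercise the collective
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

# a tensor created under inference mode is an "inference tensor"
with torch.inference_mode():
    t = torch.ones(3)

try:
    # all_reduce mutates its input in place, which is what historically
    # failed for inference tensors outside of inference mode
    dist.all_reduce(t)
    print("gloo all_reduce succeeded on an inference tensor")
except RuntimeError as e:
    print(f"still unsupported: {e}")

dist.destroy_process_group()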

@davidegraff
Author

Ah, cool, thanks for the info. I'll close this to avoid a duplicate issue. Thanks for all the great work!

@ananthsub
Contributor

@davidegraff - if you don't happen to use gloo, you can still do this to get inference mode benefits:

with torch.inference_mode():
    trainer.predict(...)
