use torch.inference_mode() in Trainer.predict #11018
Labels: feature (is an improvement or enhancement)
🚀 Feature
PyTorch 1.9 introduced a new decorator/context manager specifically for model inference: `torch.inference_mode()`. Like `torch.no_grad()`, it disables gradient tracking for faster forward passes through a model and is, as the name implies, useful for model inference. Unlike `torch.no_grad()`, it also disables view and version tracking, so tensors produced under `torch.inference_mode()` cannot later be used in computations that require gradients. Disabling this extra tracking further speeds up model forward passes.
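For reference, a small self-contained snippet (not from the Lightning code base) illustrating the difference between the two context managers:

```python
import torch

x = torch.ones(3, requires_grad=True)

# Under torch.no_grad(), gradients are not recorded, but the result is an
# ordinary tensor and may still participate in a later autograd computation.
with torch.no_grad():
    y_no_grad = x * 2
(y_no_grad * x).sum().backward()  # works; y_no_grad is treated as a constant

# Under torch.inference_mode(), view/version tracking is skipped as well,
# so the result is an "inference tensor" that cannot re-enter autograd.
with torch.inference_mode():
    y_inference = x * 2
try:
    (y_inference * x).sum().backward()
except RuntimeError as err:
    print(f"expected failure: {err}")
```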
Motivation
The motivation here is to use `torch.inference_mode()` inside the `Trainer.predict()` logic, replacing any occurrences of `torch.no_grad()`. Because this code path is used exclusively for inference, I don't believe the lack of view tracking represents a significant drawback or limitation for downstream code: conceptually, clients of `Trainer.predict()` should not be relying on the returned tensors for downstream computations that require gradients. However, clients of this code are likely relying on PyTorch Lightning to transparently scale and speed up their own code, and using `torch.inference_mode()` furthers this goal.
Pitch
Change `torch.no_grad()` to `torch.inference_mode()` in the `Trainer.predict()` code path. I think this would be the only change needed.
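A minimal sketch of the intended change, assuming the prediction loop currently wraps the forward passes in a `torch.no_grad()` context (the model, dataloader, and exact location in the Lightning source below are illustrative, not the actual implementation):

```python
import torch
from torch import nn

# Hypothetical stand-ins for the model and dataloader handled by Trainer.predict().
model = nn.Linear(4, 2)
dataloader = [torch.randn(8, 4) for _ in range(3)]

predictions = []
# Previously: with torch.no_grad():
with torch.inference_mode():  # proposed replacement
    for batch in dataloader:
        predictions.append(model(batch))
```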
cc @Borda